首页 > 代码库 > HDU 4609 3-idiots FFT

HDU 4609 3-idiots（FFT）

题目链接

 

给定 n 根火柴棍，从中任取 3 根，问它们能组成三角形的概率是多少。思路是用 FFT 求出任意两根火柴长度之和的分布，kuangbin 大神的博客写得很详细了：http://www.cnblogs.com/kuangbin/archive/2013/07/24/3210565.html

 

注意计数要用 long long 以防溢出，其他地方就没什么问题了。

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

// HDU 4609 "3-idiots": given n sticks, print the probability that three
// sticks chosen uniformly at random (without replacement) form a triangle.
//
// Approach: let cnt[v] be the number of sticks of length v. Squaring the
// polynomial sum(cnt[v] * X^v) with an FFT gives, for every s, the number of
// ordered stick pairs (a stick may pair with itself) whose lengths sum to s.
// Removing the self-pairs and halving leaves unordered pairs of distinct
// sticks. Taking each stick (in sorted order) as the longest side, the
// triangles it closes are the pairs summing to more than its length, minus
// the pairs that involve the stick itself or any stick after it.

typedef long long ll;
typedef std::complex<double> cmx;

const double PI = std::acos(-1.0);

// Permutes x[0..len) into bit-reversed index order, which the iterative
// radix-2 FFT below expects. len must be a power of two.
void change(cmx x[], int len) {
    for (int i = 1, j = len / 2; i < len - 1; ++i) {
        if (i < j) {
            std::swap(x[i], x[j]);
        }
        // Advance j as a bit-reversed counter: clear the leading set bits...
        int k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        // ...then set the first clear one.
        if (j < k) {
            j += k;
        }
    }
}

// In-place iterative radix-2 FFT over x[0..len); len must be a power of two.
// on == 1 computes the forward transform (kernel e^{-2*pi*i/m});
// on == -1 computes the inverse, including the 1/len normalisation.
void fft(cmx x[], int len, int on) {
    change(x, len);
    for (int m = 2; m <= len; m <<= 1) {
        const double angle = -on * 2 * PI / m;
        const cmx wn(std::cos(angle), std::sin(angle));
        for (int j = 0; j < len; j += m) {
            cmx w(1, 0);
            for (int k = j; k < j + m / 2; ++k) {
                const cmx u = x[k];
                const cmx v = x[k + m / 2] * w;
                x[k] = u + v;
                x[k + m / 2] = u - v;
                w *= wn;
            }
        }
    }
    if (on == -1) {
        for (int i = 0; i < len; ++i) {
            x[i] /= len;
        }
    }
}

// Returns the number of index triples i < j < k whose stick lengths form a
// triangle. `a` must be sorted ascending. Lengths are assumed positive (as
// the problem guarantees) and small enough that 2 * max length fits a buffer.
static ll count_triangles(const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    if (n < 3) {
        return 0;
    }

    const int limit = a[n - 1] + 1;  // every length lies in [0, limit)
    int len = 1;
    while (len < 2 * limit) {
        len <<= 1;  // room for every pair sum without cyclic wrap-around
    }

    std::vector<cmx> x(len, cmx(0, 0));
    for (int i = 0; i < n; ++i) {
        x[a[i]] += 1.0;
    }
    fft(&x[0], len, 1);
    for (int i = 0; i < len; ++i) {
        x[i] *= x[i];
    }
    fft(&x[0], len, -1);

    // pairs[s]: ordered pairs (self-pairs included) with length sum s,
    // recovered by rounding the inverse-FFT output to the nearest integer.
    std::vector<ll> pairs(len);
    for (int i = 0; i < len; ++i) {
        pairs[i] = static_cast<ll>(x[i].real() + 0.5);
    }
    for (int i = 0; i < n; ++i) {
        --pairs[2 * a[i]];  // drop each stick paired with itself
    }
    for (int i = 0; i < len; ++i) {
        pairs[i] /= 2;  // ordered -> unordered pairs of distinct sticks
    }
    for (int i = 1; i < len; ++i) {
        pairs[i] += pairs[i - 1];  // now: pairs with length sum <= i
    }

    ll ans = 0;
    for (int i = 0; i < n; ++i) {
        const ll after = n - 1 - i;            // sticks sorted after stick i
        ans += pairs[len - 1] - pairs[a[i]];   // pairs with sum > a[i]
        ans -= after * i;                      // one stick after i, one before
        ans -= after * (after - 1) / 2;        // both sticks after i
        ans -= n - 1;                          // pairs containing stick i
    }
    return ans;
}

int main() {
    int t = 0;
    if (std::scanf("%d", &t) != 1) {
        return 0;
    }
    while (t-- > 0) {
        int n = 0;
        if (std::scanf("%d", &n) != 1) {
            break;
        }
        std::vector<int> a(n > 0 ? n : 0);
        for (int i = 0; i < n; ++i) {
            std::scanf("%d", &a[i]);
        }
        std::sort(a.begin(), a.end());

        const ll total = 1LL * n * (n - 1) * (n - 2) / 6;  // C(n, 3)
        // With fewer than three sticks no triple exists; report 0 rather
        // than dividing by zero (which would print nan/inf).
        const double p =
            total > 0 ? static_cast<double>(count_triangles(a)) / total : 0.0;
        std::printf("%.7f\n", p);
    }
    return 0;
}

 

HDU 4609 3-idiots（FFT）