首页 > 代码库 > bzoj3992 [SDOI2015]序列统计

bzoj3992 [SDOI2015]序列统计

传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=3992

【题解】

很容易得到一个 DP：设 $f_{i,j}$ 为长度为 $i$、乘积模 $m$ 为 $j$ 的序列个数，每次枚举下一个元素转移，复杂度为 $O(nm|S|)$，$n$ 很大时显然无法接受。

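由于 $m$ 是质数，它存在原根，于是可以用原根（离散对数）把乘法改成加法。注意 $0$ 没有离散对数；而题目保证 $1 \le x \le m-1$，含 $0$ 的序列乘积为 $0$，不可能满足条件，所以直接忽略集合中的 $0$ 即可。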

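设 $g$ 为 $m$ 的一个原根，对每个非零元素取离散对数 $b_i$，使 $a_i \equiv g^{b_i} \pmod m$，则 $a_1 a_2 \cdots a_n \equiv g^{b_1 + b_2 + \cdots + b_n} \pmod m$。

再找到 $k$ 使 $g^k \equiv x \pmod m$，那么条件等价于 $b_1 + b_2 + \cdots + b_n \equiv k \pmod{m-1}$（$m-1$ 就是 $\varphi(m)$，即原根的阶）。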

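考虑生成函数：令 $F(z) = \sum_{s \in S,\, s \ne 0} z^{\operatorname{ind}(s)}$（$\operatorname{ind}(s)$ 为 $s$ 关于 $g$ 的离散对数），答案即为 $F(z)^n \bmod (z^{m-1} - 1)$ 中 $z^k$ 项的系数。

注意这里的取模是循环卷积意义下的：乘积中次数 $\ge m-1$ 的项要把次数对 $m-1$ 取模后加回到前面的低次项上，而不是直接丢弃。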

这样就可以用 NTT（模数 $1004535809 = 479 \cdot 2^{21} + 1$，原根 $3$）配合快速幂求解了，复杂度 $O(m \log m \log n)$。

技术分享
# include <stdio.h>
# include <string.h>
# include <algorithm>
// # include <bits/stdc++.h>

using namespace std;

// Integer aliases kept for compatibility (not all are used in this file).
using ll = long long;
using ld = long double;
using ull = unsigned long long;

// M: capacity of tables indexed by residues modulo m (see fp below).
// N: capacity of polynomial coefficient arrays (pa::a); must be >= the NTT length.
constexpr int M = 8000 + 10, N = 200000 + 10;
// NTT-friendly prime 479 * 2^21 + 1 and the generator used for its roots of unity.
constexpr int mod = 1004535809;
constexpr int G = 3;

// `register` was removed in C++17 (it was only an ignored hint before), so RG
// expands to nothing: any `RG int x;` keeps compiling under every standard.
# define RG
# define ST static

// Problem input: n = sequence length, m = modulus, X = target product, S = set size.
int n, m, X, S;
// fp[v] = discrete logarithm of v (1 <= v < m) w.r.t. the primitive root found in main.
int fp[M];

// Dense coefficient array of a polynomial over Z_mod: a[i] is the coefficient
// of z^i. Global instances are zero-initialised.
struct pa {
    int a[N];
};

// A: base polynomial (repeatedly squared in main); ans: running result.
pa A, ans;

/// Modular exponentiation by repeated squaring: returns a^b mod P.
/// Products are formed in 64-bit arithmetic, so any int modulus is safe.
/// b must be non-negative; b == 0 yields 1 (not reduced by P).
inline int pwr(int a, int b, int P) {
    long long result = 1;
    long long base = a;
    for (; b != 0; b >>= 1) {
        if (b & 1) result = result * base % P;
        base = base * base % P;
    }
    return static_cast<int>(result);
}

// Exponents tested by getprt: after getprt(m) runs, t[1..tn] holds (m-1)/p
// for every distinct prime p dividing m-1.
int t[2333], tn=0;

/// Returns the smallest primitive root modulo m, where m is assumed prime.
/// A unit g generates Z_m^* iff g^((m-1)/p) != 1 (mod m) for every prime p | m-1,
/// so only those few exponents are checked.
/// Returns -1 when no candidate in [2, m) passes (e.g. m < 2 or m not prime),
/// instead of searching forever.
inline int getprt(int m) {
    if (m == 2) return 1;  // Z_2^* = {1}; 2 would be congruent to 0

    // Factor m-1 by trial division, recording (m-1)/p once per distinct prime p.
    tn = 0;
    int rest = m - 1;
    for (int p = 2; 1ll * p * p <= rest; ++p) {
        if (rest % p != 0) continue;
        t[++tn] = (m - 1) / p;
        while (rest % p == 0) rest /= p;
    }
    if (rest > 1) t[++tn] = (m - 1) / rest;  // leftover prime factor

    // For prime m the smallest primitive root is < m, so the search is bounded.
    for (int g = 2; g < m; ++g) {
        bool ok = true;
        for (int j = 1; j <= tn && ok; ++j)
            if (pwr(g, t[j], m) == 1) ok = false;
        if (ok) return g;
    }
    return -1;
}


namespace NTT {
    const int M = 2e5 + 10;
    int n, w[2][M], lst[M], invn;
    inline void init(int _n) {
        n = 1;
        while(n < _n) n <<= 1;
        w[0][0] = 1, w[1][0] = 1;
        int g = pwr(G, (mod-1)/n, mod), invg = pwr(g, mod-2, mod);
        for (int i=1; i<n; ++i) w[0][i] = 1ll * w[0][i-1] * g % mod, w[1][i] = 1ll * w[1][i-1] * invg % mod;
        int len = 0;
        while((1<<len) < n) ++len;
        for (int i=0; i<n; ++i) {
            int t = 0;
            for (int j=0; j<len; ++j) if(i&(1<<j)) t |= (1<<(len-j-1));
            lst[i] = t;
        }
        invn = pwr(n, mod-2, mod);
    }
    inline void DFT(int *a, int op) {
        int *o = w[op];
        for (int i=0; i<n; ++i) if(i < lst[i]) swap(a[i], a[lst[i]]);
        for (int len=2; len<=n; len<<=1) {
            int m = (len>>1);
            for (int *p = a; p != a+n; p += len) {
                for (int k=0; k<m; ++k) {
                    int t = 1ll * o[n/len*k] * p[k+m] % mod;
                    p[k+m] = p[k] - t; if(p[k+m] < 0) p[k+m] += mod;
                    p[k] = p[k] + t; if(p[k] >= mod) p[k] -= mod;                    
                }
            }
        }
        if(op) {
            for (int i=0; i<n; ++i) a[i] = 1ll * a[i] * invn % mod;
        }
    }
    
    inline void mul(int *x, pa A, pa B, int Mod) {
        DFT(A.a, 0); DFT(B.a, 0);
        for (int i=0; i<n; ++i) A.a[i] = 1ll * A.a[i] * B.a[i] % mod;
        DFT(A.a, 1);
        for (int i=0; i<n; ++i) x[i] = 0;
        for (int i=0; i<n; ++i) {
            int np = i % Mod;
            x[np] = x[np] + A.a[i];
            if(x[np] >= mod) x[np] -= mod;
        }
    }
}


int main() {
    if (scanf("%d%d%d%d", &n, &m, &X, &S) != 4) return 1;

    // Handle n <= 0 up front: after --n below, a non-positive n would go negative,
    // and `n >>= 1` never reaches 0, so the power loop would never terminate.
    // The only length-0 sequence is the empty one, whose product is 1.
    if (n <= 0) {
        printf("%d\n", (n == 0 && X % m == 1) ? 1 : 0);
        return 0;
    }

    // Discrete-log table for the primitive root gg: fp[gg^i mod m] = i, 0 <= i < m-1.
    int gg = getprt(m), sum = 1;
    for (int i = 0; i < m - 1; ++i) {
        fp[sum] = i;
        sum = 1ll * sum * gg % m;
    }

    // Base polynomial A(z) = sum of z^fp[s] over nonzero s in the set. 0 has no
    // discrete log, and a sequence containing 0 has product 0, which never equals
    // a target X in [1, m-1]. ans starts as A, i.e. one factor is already taken.
    for (int i = 1; i <= S; ++i) {
        int pt;
        if (scanf("%d", &pt) != 1) return 1;
        if (pt) A.a[fp[pt]] = ans.a[fp[pt]] = 1;
    }

    // Factors have degree <= m-2, so a product has degree <= 2m-4: length >= 2m
    // leaves no wraparound inside the NTT.
    NTT::init(m + m);

    // Binary exponentiation: ans = A^n mod (z^(m-1) - 1).
    --n;
    while (n) {
        if (n & 1) NTT::mul(ans.a, ans, A, m - 1);
        n >>= 1;
        if (n) NTT::mul(A.a, A, A, m - 1);  // skip the final square, which is never used
    }

    // Coefficient of z^k, where gg^k == X (mod m).
    printf("%d\n", ans.a[fp[X]]);
    return 0;
}
View Code

 

bzoj3992 [SDOI2015]序列统计