首页 > 代码库 > BZOJ 3992 [SDOI2015]序列统计

BZOJ 3992 [SDOI2015]序列统计

数列长度到了$10^9$,转移矩阵边长m到了8000,除了FFT还能怎么写??!!

当然,这题由于取模,必须用NTT.

同时由于取的是乘积,所以用m的原根把乘积转成指数(离散对数)上的加法;每次NTT完了,把下标不小于m-1的部分加回到前面去(指数要对m-1取模).

注意,X不会是0,因此一旦S集合中出现0,直接删掉即可:0没有离散对数,原根表示不了0.

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<string>
#include<cmath>
#include<ctime>
#include<algorithm>
#include<map>
#include<set>
#include<queue>
#include<iomanip>
using namespace std;
// Shorthands used throughout the file.
#define ll long long
#define db double 
// Inclusive counting loop: i takes every value from j up to and including n.
// The macro arguments are pasted unparenthesized, so pass simple expressions.
#define up(i,j,n) for(ll i=j;i<=n;i++)
#define pii pair<ll,ll>
// Expands to `unsigned long long`.
#define uint unsigned ll
// Base name of the redirect files "dealing.in" / "dealing.out" opened in main().
// NOTE(review): this macro shadows stdio's FILE type for every line after it,
// so `FILE*` can no longer be spelled below; a distinct name would be safer.
#define FILE "dealing"
// Reads the next integer from stdin.
// Skips any non-digit characters first. A '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit. The original kept looping forever
// on EOF, because EOF never satisfies the skip-loop condition.
long long read(){
	long long x=0,f=1;
	int ch=getchar();
	while(ch!=EOF&&(ch<'0'||ch>'9')){
		if(ch=='-')f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+(ch-'0');
		ch=getchar();
	}
	return x*f;
}
// Raise a to b when b is larger; reports whether a was changed.
template<class T> bool cmax(T& a,T b){
	if(!(a<b))return false;
	a=b;
	return true;
}
// Lower a to b when b is smaller; reports whether a was changed.
template<class T> bool cmin(T& a,T b){
	if(!(a>b))return false;
	a=b;
	return true;
}
// Capacity of the global work arrays (and of the NTT buffers).
const ll maxn=400100;
// limit and inf are not referenced anywhere in the visible code.
const ll limit=128;
const ll inf=1000000000;
// Generator used to build the NTT roots of unity.
const ll r=3;
// Answer modulus. mod-1 = 479*2^21, so power-of-two transform lengths up to 2^21 exist.
const ll mod=1004535809;
// n: sequence length, m: prime modulus of the product, X: required product,
// S: number of set elements, G: primitive root of m (computed in main).
ll n,m,X,S,G;
// a: the set's elements (replaced by their discrete logs in main),
// b: not referenced at global scope in the visible code,
// id: discrete-log table, id[G^i mod m] = i.
ll a[maxn],b[maxn],id[maxn];
// Binary exponentiation: returns a^b reduced modulo `mod`.
// Callers are expected to pass 0 <= a < mod and b >= 0.
long long fast(long long a,long long b,long long mod){
	long long result=1;
	for(;b!=0;b>>=1){
		if(b&1)result=result*a%mod;
		a=a*a%mod;
	}
	return result;
}
// Debug helper: writes a[0..len-1] on one line, each value followed by a space,
// then ends the line (flushing stdout via endl).
void print(long long* a,long long len){
	for(long long k=0;k<len;++k){
		printf("%lld ",a[k]);
	}
	std::cout<<std::endl;
}
// Finds a primitive root of a prime modulus.
namespace prepare{
	const ll maxn=8080;
	// b: composite marks, prime[1..tail]: primes below maxn,
	// q[1..head]: distinct prime factors of N-1 found by the last solve().
	ll b[maxn],prime[maxn],tail,q[maxn],head;
	// Linear (Euler) sieve filling prime[1..tail] with every prime below maxn.
	// Idempotent: a second run would otherwise append every prime again.
	void getprime(){
		if(tail)return;
		for(ll i=2;i<maxn;i++){
			if(!b[i])prime[++tail]=i;
			// Check j<=tail before reading prime[j] so no slot past the list is touched.
			for(ll j=1;j<=tail&&prime[j]*i<maxn;j++){
				b[i*prime[j]]=1;
				if(i%prime[j]==0)break;
			}
		}
	}
	// Returns the smallest primitive root of the prime N (1 when N == 2), or 0 if none is found.
	// The factorization of N-1 is exact as long as N-1 < maxn*maxn: any factor left after
	// dividing by all sieved primes is then itself prime.
	ll solve(ll N){
		getprime();
		const ll p=N-1;
		ll rest=p;
		head=0;                        // drop factors recorded by any earlier call
		for(ll i=1;i<=tail&&rest>1;i++){
			if(rest%prime[i]==0)q[++head]=prime[i];
			while(rest%prime[i]==0)
				rest/=prime[i];
		}
		if(rest>1)q[++head]=rest;      // prime factor larger than every sieved prime
		// g is a primitive root iff g^(p/q) != 1 (mod N) for every prime q dividing p.
		// Starting at 1 only matters for N == 2 (no factors, so 1 qualifies); for N >= 3
		// the value 1 always fails the test.
		for(ll g=1;g<=p;g++){
			bool isRoot=true;
			for(ll j=1;j<=head&&isRoot;j++)
				if(fast(g,p/q[j],N)==1)isRoot=false;
			if(isRoot)return g;
		}
		return 0;
	}
};
// Number-theoretic transform modulo `mod` with generator `r`, used for polynomial multiplication.
namespace NTT{
	// R: bit-reversal permutation for the current length, a/b: private copies of the operands,
	// w: twiddle factors of one butterfly level.
	ll R[maxn],a[maxn],b[maxn],w[maxn];
	// L: transform length (a power of two), H: log2(L).
	ll H,L;
	// In-place iterative NTT of a[0..L-1] over Z/mod.
	// flag == 0: forward transform; flag != 0: inverse transform, including the 1/L scaling.
	// Requires R[] to hold the bit-reversal permutation for the current L (set up by solve).
	void NTT(ll* a,ll flag){
		for(ll i=0;i<L;i++)if(i<R[i])swap(a[i],a[R[i]]);
		for(ll len=2;len<=L;len<<=1){
			// Principal len-th root of unity. len must divide mod-1, i.e. len <= 2^21 here.
			ll g=fast(r,(mod-1)/len,mod),l=len>>1;
			if(flag)g=fast(g,mod-2,mod);
			// Only w[0..l-1] are used by the butterflies below.
			w[0]=1;up(i,1,l-1)w[i]=w[i-1]*g%mod;
			for(ll st=0;st<L;st+=len)
				for(ll k=0;k<l;k++){
					ll x=a[st+k],y=w[k]*a[st+k+l]%mod;
					a[st+k]=(x+y)%mod,a[st+k+l]=(x-y+mod)%mod;
				}
		}
		if(flag){
			ll inv=fast(L,mod-2,mod);
			up(i,0,L-1)a[i]=a[i]*inv%mod;
		}
	}
	// Multiplies c (coefficients c[0..n]) by d (coefficients d[0..m]) and stores the product
	// coefficients in ch[0..n+m]. The inputs are copied into a/b first, so ch may alias c or d.
	// Assumes the padded length fits in maxn and divides mod-1.
	void solve(ll* c,ll* d,ll n,ll m,ll* ch){
		n++,m++;                       // from here on n and m count coefficients
		up(i,0,n-1)a[i]=c[i];
		up(i,0,m-1)b[i]=d[i];
		for(H=0,L=1;L<n+m-1;H++)L<<=1;
		up(i,n,L-1)a[i]=0;
		up(i,m,L-1)b[i]=0;
		// Bit-reversal table. When L == 1 (H == 0) the loop is empty, so the shift by
		// H-1 == -1 (undefined behaviour) is never evaluated.
		R[0]=0;
		up(i,1,L-1)R[i]=(R[i>>1]>>1)|((i&1)<<(H-1));
		NTT(a,0);NTT(b,0);
		up(i,0,L-1)a[i]=a[i]*b[i]%mod;
		NTT(a,1);
		up(i,0,n+m-2)ch[i]=a[i];
	}
};
ll c[maxn],ans[maxn],tmp[maxn];
int main(){
	freopen(FILE".in","r",stdin);
	freopen(FILE".out","w",stdout);
	n=read();m=read();X=read(),S=read();
	up(i,1,S)a[i]=read();
	G=prepare::solve(m);
	ll w=1;
	up(i,0,m-1){
		id[w]=i;
		w=w*G%m;
	}
	up(i,1,S)if(a[i])a[i]=id[a[i]];X=id[X];
	up(i,1,S)if(a[i])c[a[i]]++;
	ans[0]=1;
	while(n){
		if(n&1){
			NTT::solve(ans,c,m,m,tmp);
			up(i,0,m-1)ans[i]=tmp[i];
			up(i,1,m)ans[i]=(ans[i]+tmp[i+m-1])%mod;
		}
		n>>=1;
		NTT::solve(c,c,m,m,tmp);
		up(i,0,m-1)c[i]=tmp[i];
		up(i,1,m)c[i]=(c[i]+tmp[i+m-1])%mod;
	}
	printf("%lld\n",ans[X]);
	return 0;
}

  

BZOJ 3992 [SDOI2015]序列统计