首页 > 代码库 > 多项式乘法

多项式乘法

xsy上的NTT模板题。原来不知道FFT如何改成NTT，听了yww的讲解后觉得它挺奇妙的，幸好还在能接受的范围内。

 

 1 #include<stdio.h>
 /* ---- Limits and NTT parameters ---- */

/* Array capacity: 2^17 = 131072 transform points plus a little slack
   (wn[] is indexed up to and including N). */
#define maxn 131079LL

/* NTT-friendly prime: 998244353 = 119 * 2^23 + 1, primitive root 3.
   Alternative choice: 1004535809 = 479 * 2^21 + 1, also primitive root 3. */
#define mod 998244353LL

/* Primitive root modulo `mod`: among exponents 1..mod-1,
   g^x % mod == 1 only for x == mod-1. */
#define g 3LL

#define ll long long

/* Coefficient arrays of the two input polynomials (transformed in place). */
ll a[maxn];
ll b[maxn];

/* Scratch buffer holding the bit-reversed copy during a transform. */
ll A[maxn];

/* Transform length (a power of two) and its base-2 logarithm. */
ll N;
ll logN;

/* Scratch: binary digits of an index while bit-reversing it. */
ll logi[20];

/* Twiddle table: wn[len] is a len-th root of unity mod `mod`,
   for len = 2, 4, ..., N. */
ll wn[maxn];
11 ll pow(ll a,ll b){
12     if(b==1)return a;
13     ll ans=pow(a,b>>1);
14     ans=ans*ans%mod;
15     if(b&1)ans=ans*a%mod;
16     return ans;
17 }
18 void ntt(ll*a,ll on){
19     ll i,j,k,w,t,u;
20     for(i=0;i<N;i++){
21         k=i;
22         for(j=0;j<logN;j++){
23             logi[j]=k&1;
24             k>>=1;
25         }
26         k=0;
27         for(j=0;j<logN;j++)k=(k<<1)+logi[j];
28         A[k]=a[i];
29     }
30     if(on==-1){
31         for(i=2;i<=N;i<<=1)
32             wn[i]=pow(wn[i],mod-2);
33     }
34     for(i=2;i<=N;i<<=1){
35         for(j=0;j<N;j+=i){
36             w=1;
37             for(k=0;k<(i>>1);k++){
38                 t=w*A[j+k+(i>>1)]%mod;
39                 u=A[j+k];
40                 A[j+k]=(u+t)%mod;
41                 A[j+k+(i>>1)]=((u-t)%mod+mod)%mod;
42                 w=w*wn[i]%mod;
43             }
44         }
45     }
46     if(on==-1){
47         for(i=2;i<=N;i<<=1)
48             wn[i]=pow(wn[i],mod-2);
49         k=pow(N,mod-2);
50     }
51     for(i=0;i<N;i++){
52         a[i]=A[i];
53         if(on==-1)a[i]=a[i]*k%mod;
54     }
55 }
56 int main(){
57     ll n,m,i;
58     scanf("%lld%lld%lld",&n,&m,&i);
59     N=1;
60     logN=0;
61     n++;
62     m++;
63     while(N<n||N<m){
64         N<<=1;
65         logN++;
66     }
67     N<<=1;
68     logN++;
69     wn[N]=pow(g,(mod-1)/N);
70     for(i=N>>1;i>1;i>>=1)wn[i]=wn[i<<1]*wn[i<<1]%mod;
71     for(i=0;i<n;i++)scanf("%lld",a+i);
72     for(i=0;i<m;i++)scanf("%lld",b+i);
73     ntt(a,1);
74     ntt(b,1);
75     for(i=0;i<N;i++)a[i]=a[i]*b[i]%mod;
76     ntt(a,-1);
77     for(i=0;i<n+m-1;i++)printf("%lld ",a[i]);
78 }

 

多项式乘法