A * B Problem Plus
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others)
Total Submission(s): 9413 Accepted Submission(s): 1468
Note: the length of each integer will not exceed 50000.
就一个高精度乘法 FFT加速。
最近正好要捡起fft,就顺便整理了模板。
FFT的原理还是算法导论靠谱,没有那么艰深难懂,就涉及怎么进行FFT和FFT需要的原理和定理。
看看算法导论里FFT的部分,一定要读到迭代实现那部分!!
看了好久求和引理,才发觉他是为了保证$w_n^k$与$w_n^{k+2/h}$的对称性(即$w_n^{k+2/h}=-w_n^k$)的,这个引理是必要的。
对于多项式序列,我们可以用两个O(nlgn)(n>max(len1,len2)*2)的FFT将其系数表示转化为点值表示(DFT),然后用O(n) 相乘,接着用FFT把结果的点值表示变为系数表示(IDFT),总体算起来是3O(nlgn)+O(n),即O(nlgn)的时间复杂度。比O(n^2)好多了。
以下是学习的两个版本。

1 #include<bits/stdc++.h> 2 #define clr(x) memset(x,0,sizeof(x)) 3 #define clr_1(x) memset(x,-1,sizeof(x)) 4 #define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x)) 5 #define LL long long 6 #define mod 1000000007 7 #define PI 3.1415926535 8 using namespace std; 9 char s1[200010],s2[200010]; 10 int a[200010],b[200010]; 11 //复数序列结构体 12 struct complexed 13 { 14 double r,i; 15 complexed(double _r=0.0,double _i=0.0) 16 { 17 r=_r; 18 i=_i; 19 } 20 complexed operator +(complexed b) 21 { 22 return complexed(r+b.r,i+b.i); 23 } 24 complexed operator -(complexed b) 25 { 26 return complexed(r-b.r,i-b.i); 27 } 28 complexed operator *(complexed b) 29 { 30 return complexed(r*b.r-i*b.i,i*b.r+r*b.i); 31 } 32 }num1,num2; 33 vector<complexed> multi1,multi2; 34 inline int max(int a,int b) 35 { 36 return a>b?a:b; 37 } 38 //并将长度变为2…^(k+1) 39 void changelen(int &len) 40 { 41 int mul=1; 42 while(mul<len) 43 mul<<=1; 44 mul<<=1; 45 len=mul; 46 return ; 47 } 48 //将整数序列复制到复数序列中 49 void copyed(int *a,vector<complexed> &multi,int len) 50 { 51 multi.resize(len); 52 for(int i=0;i<len;i++) 53 multi[i]=(complexed){a[i],0}; 54 return; 55 } 56 //DFT的话on=1,IDFT on=-1; 57 void fft(vector<complexed> &multi,int len,int on) 58 { 59 complexed wn,w,u,t; 60 //wn,w,u,t如算法导论中所示 61 vector<complexed> ans; 62 ans.resize(len); 63 //ans存每次操作计算后的y,最后再作为下次的multi。 64 for(int h=len/2;h>=1;h>>=1) 65 { 66 wn=(complexed){cos(2*on*PI/(len/h)),sin(2*on*PI/(len/h))}; 67 for(int i=0;i<h;i++) 68 { 69 w=(complexed){1,0}; 70 for(int j=0;j<len/h/2;j++) 71 { 72 //蝴蝶操作 73 u=multi[i+2*h*j]; 74 t=multi[i+2*h*j+h]*w; 75 ans[i+h*j]=u+t; 76 ans[i+h*j+len/2]=u-t; 77 w=w*wn; 78 } 79 } 80 //ans作为下次计算的multi 81 multi=ans; 82 } 83 //IDFT每个元素都得除以n 84 if(on==-1) 85 for(int i=0;i<len;i++) 86 multi[i].r/=len; 87 return ; 88 } 89 int main() 90 { 91 int len1,len2,len; 92 while(scanf("%s%s",s1,s2)!=EOF) 93 { 94 len1=strlen(s1); 95 len2=strlen(s2); 96 clr(a); 97 clr(b); 98 for(int i=0;i<len1;i++) 99 { 100 a[len1-i-1]=s1[i]-'0'; 101 } 102 for(int i=0;i<len2;i++) 103 { 104 b[len2-i-1]=s2[i]-'0'; 105 } 106 len=max(len1,len2); 107 //取长度较长者作为长度,并将长度变为2…^(k+1) 108 changelen(len); 109 //将两个整数序列复制到复数序列中 110 copyed(a,multi1,len); 111 copyed(b,multi2,len); 112 //对两个复数序列进行DFT,变为点值表示 113 fft(multi1,len,1); 114 fft(multi2,len,1); 115 //对应点点值相乘 116 for(int i=0;i<len;i++) 117 multi1[i]=multi1[i]*multi2[i]; 118 //将的出来的点值表示进行IDFT变为系数表示 119 fft(multi1,len,-1); 120 //四舍五入减小损失精度 121 for(int i=0;i<len;i++) 122 { 123 a[i]=(int)(multi1[i].r+0.5); 124 } 125 //进位 126 for(int i=0;i<len;i++) 127 { 128 a[i+1]=a[i+1]+a[i]/10; 129 a[i]%=10; 130 } 131 len=len1+len2-1; 132 //去掉前导0 133 while(a[len]<=0 && len>0) len--; 134 for(int i=len;i>=0;i--) 135 printf("%d",a[i]); 136 printf(" "); 137 } 138 return 0; 139 }

1 #include<bits/stdc++.h> 2 #define clr(x) memset(x,0,sizeof(x)) 3 #define clr_1(x) memset(x,-1,sizeof(x)) 4 #define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x)) 5 #define LL long long 6 #define mod 1000000007 7 #define PI 3.1415926535 8 using namespace std; 9 char s1[200010],s2[200010]; 10 int a[200010],b[200010]; 11 struct complexed 12 { 13 double r,i; 14 complexed(double _r=0.0,double _i=0.0) 15 { 16 r=_r; 17 i=_i; 18 } 19 complexed operator +(complexed b) 20 { 21 return complexed(r+b.r,i+b.i); 22 } 23 complexed operator -(complexed b) 24 { 25 return complexed(r-b.r,i-b.i); 26 } 27 complexed operator *(complexed b) 28 { 29 return complexed(r*b.r-i*b.i,i*b.r+r*b.i); 30 } 31 }num1,num2; 32 complexed multi1[200010<<2],multi2[200010<<2]; 33 inline int max(int a,int b) 34 { 35 return a>b?a:b; 36 } 37 void changelen(int &len) 38 { 39 int mul=1; 40 while(mul<len) 41 mul<<=1; 42 mul<<=1; 43 len=mul; 44 return ; 45 } 46 //将整数序列复制到复数序列中 47 void copyed(int *a,complexed *multi,int len) 48 { 49 for(int i=0;i<len;i++) 50 multi[i]=(complexed){a[i],0}; 51 return; 52 } 53 //位逆序变换 54 void bitchange(complexed *multi,int len) 55 { 56 int i,j,k; 57 for(i = 1, j = len/2;i < len-1; i++) 58 { 59 if(i < j)swap(multi[i],multi[j]); 60 k = len/2; 61 while( j >= k) 62 { 63 j -= k; 64 k /= 2; 65 } 66 if(j < k) j += k; 67 } 68 return ; 69 } 70 //DFT的话on=1,IDFT on=-1; 71 void fft(complexed *multi,int len,int on) 72 { 73 bitchange(multi,len);//位逆序置换 74 complexed wn,w,u,t;//如算法导论所示 75 for(int h=2;h<=len;h<<=1) 76 { 77 wn=(complexed){cos(2*on*PI/h),sin(2*on*PI/h)}; 78 for(int i=0;i<len;i+=h) 79 { 80 //蝴蝶操作 81 w=(complexed){1,0}; 82 for(int j=i;j<i+h/2;j++) 83 { 84 u=multi[j]; 85 t=multi[j+h/2]*w; 86 multi[j]=u+t; 87 multi[j+h/2]=u-t; 88 w=w*wn; 89 } 90 } 91 } 92 //IDFT每个元素都得除以n 93 if(on==-1) 94 for(int i=0;i<len;i++) 95 multi[i].r/=len; 96 return ; 97 } 98 void mul(int *a,int *b,int &len1,int &len2) 99 { 100 int len=max(len1,len2); 101 //取长度较长者作为长度,并将长度变为2…^(k+1) 102 changelen(len); 103 //将两个整数序列复制到复数序列中 104 copyed(a,multi1,len); 105 copyed(b,multi2,len); 106 //对两个复数序列进行DFT,变为点值表示 107 fft(multi1,len,1); 108 fft(multi2,len,1); 109 //对应点点值相乘 110 for(int i=0;i<len;i++) 111 multi1[i]=multi1[i]*multi2[i]; 112 //将的出来的点值表示进行IDFT变为系数表示 113 fft(multi1,len,-1); 114 //四舍五入减小损失精度 115 for(int i=0;i<len;i++) 116 { 117 a[i]=(int)(multi1[i].r+0.5); 118 } 119 while(len-1>0 && a[len-1]==0) 120 len--; 121 len1=len; 122 return ; 123 } 124 int main() 125 { 126 int len1,len2,len; 127 while(scanf("%s%s",s1,s2)!=EOF) 128 { 129 len1=strlen(s1); 130 len2=strlen(s2); 131 clr(a); 132 clr(b); 133 for(int i=0;i<len1;i++) 134 { 135 a[len1-i]=s1[i]-'0'; 136 } 137 for(int i=0;i<len2;i++) 138 { 139 b[len2-i]=s2[i]-'0'; 140 } 141 mul(a+1,b+1,len1,len2); 142 //进位 143 len=len1; 144 for(int i=1;i<len;i++) 145 { 146 a[i+1]=a[i+1]+a[i]/10; 147 a[i]%=10; 148 } 149 while(a[len]>9) 150 { 151 a[len+1]=a[len+1]+a[len]/10; 152 a[len]%=10; 153 len++; 154 } 155 for(int i=len;i>=1;i--) 156 printf("%d",a[i]); 157 printf(" "); 158 } 159 return 0; 160 }
后来看了ntt,小改了下原迭代实现的模板,实现了迭代实现的NTT模板:

1 #include<bits/stdc++.h> 2 #define clr(x) memset(x,0,sizeof(x)) 3 #define clr_1(x) memset(x,-1,sizeof(x)) 4 #define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x)) 5 #define LL long long 6 #define mod 1004535809 7 #define PI 3.1415926535 8 #define P 1004535809 9 #define G 3 10 using namespace std; 11 char s1[200010],s2[200010]; 12 LL a[200010],b[200010],c[200010]; 13 LL quick_pow(LL mul,LL n) 14 { 15 LL res=1; 16 mul=(mul%mod+mod)%mod; 17 while(n) 18 { 19 if(n%2) 20 res=res*mul%mod; 21 mul=mul*mul%mod; 22 n/=2; 23 } 24 return res; 25 } 26 inline int max(int a,int b) 27 { 28 return a>b?a:b; 29 } 30 void bitchange(LL *a,int len) 31 { 32 int i,j,k; 33 for(i = 1, j = len>>1;i < len-1; i++) 34 { 35 if(i < j)swap(a[i],a[j]); 36 k = len>>1; 37 while( j >= k) 38 { 39 j -= k; 40 k >>= 1; 41 } 42 if(j < k) j += k; 43 } 44 return ; 45 } 46 void changelen(int &len) 47 { 48 int mul=1; 49 while(mul<len) 50 mul<<=1; 51 mul<<=1; 52 len=mul; 53 return ; 54 } 55 //DFT的话on=1,IDFT on=-1; 56 void ntt(LL *a,int len,LL on) 57 { 58 bitchange(a,len);//位逆序置换 59 LL wn,w,u,t;//如算法导论所示 60 for(int h=2;h<=len;h<<=1) 61 { 62 wn=quick_pow(G,(P-1)/h)%mod; 63 for(int i=0;i<len;i+=h) 64 { 65 //蝴蝶操作 66 w=1; 67 for(int j=i;j<i+h/2;j++) 68 { 69 u=a[j]%mod; 70 t=a[j+h/2]*w%mod; 71 a[j]=(u+t)%mod; 72 a[j+h/2]=(u-t+mod)%mod; 73 w=w*wn%mod; 74 } 75 } 76 } 77 //IDFT调换次序实现wn^-1的情况,并且乘以len的逆元 78 if(on==-1) 79 { 80 //k^0显然不调换次序,但是k^1与k^-1,k^2与k^-2.... k^len/2与k^-len/2 要调换次序 81 for(int i=1;i<len/2;i++) 82 swap(a[i],a[len-i]); 83 LL re=quick_pow(len,P-2); 84 for(int i=0;i<len;i++) 85 a[i]=a[i]*re%mod; 86 } 87 return ; 88 } 89 void mul(LL *a,LL *b,int &len1,int &len2) 90 { 91 int len=max(len1,len2); 92 //取长度较长者作为长度,并将长度变为2…^(k+1) 93 changelen(len); 94 //对两个整数序列进行DFT,变为点值表示 95 ntt(a,len,1); 96 ntt(b,len,1); 97 //对应点点值相乘 98 for(int i=0;i<len;i++) 99 a[i]=b[i]*a[i]%mod; 100 //将的出来的点值表示进行IDFT变为系数表示 101 ntt(a,len,-1); 102 while(len-1>0 && a[len-1]==0) 103 len--; 104 len1=len; 105 return ; 106 } 107 int main() 108 { 109 int len1,len2,len; 110 while(scanf("%s%s",s1,s2)!=EOF) 111 { 112 len1=strlen(s1); 113 len2=strlen(s2); 114 clr(a); 115 clr(b); 116 for(int i=0;i<len1;i++) 117 { 118 a[len1-i]=s1[i]-'0'; 119 } 120 for(int i=0;i<len2;i++) 121 { 122 b[len2-i]=s2[i]-'0'; 123 } 124 mul(a+1,b+1,len1,len2); 125 //进位 126 len=len1; 127 for(int i=1;i<len;i++) 128 { 129 a[i+1]=a[i+1]+a[i]/10; 130 a[i]%=10; 131 } 132 while(a[len]>9) 133 { 134 a[len+1]=a[len+1]+a[len]/10; 135 a[len]%=10; 136 len++; 137 } 138 for(int i=len;i>=1;i--) 139 printf("%lld",a[i]); 140 printf(" "); 141 } 142 return 0; 143 }
NTT需要爆搜下找到该质数的原根(这部分一般不写到代码里,一般是自己找出来以后再直接作为常量放在程序里,建议分解完P-1的质因数后去搜索快点,一般原根都不太大比较好搜)。在比赛中一般给出的质数P,P-1后一般是C*2^k的形式,才能支持2^k的分治。
学习资料推荐:http://blog.sina.com.cn/s/blog_7c4c33190102wht6.html 这个看下原理一类的,包括FFT的。其中笔者把(P-1)*2^m写错写成了P*2^m了。
代码以及等价性参考ACdreamer的代码:http://blog.csdn.net/acdreamers/article/details/39026505