显然因为我不会数学,所以这篇文章会非常“ 感性 ”。
题目
将两个多项式乘起来,即求 (f*g=h) 。多项式的项数 (nle10^5)。
FFT
前置知识
复数
复数是指形如 (x+yi) 的数,高中会教。
它的四则运算法则是这样的:(令 (p,q) 为两个复数)
除法用不着。
所以代码是这样的:(建议不要用 STL 的复数,常数巨大)
struct mle{lod x,y;}a[n7],b[n7];
mle operator + (mle p,mle q){return (mle){p.x+q.x,p.y+q.y};}
mle operator - (mle p,mle q){return (mle){p.x-q.x,p.y-q.y};}
mle operator * (mle p,mle q){return (mle){p.x*q.x-p.y*q.y,p.x*q.y+p.y*q.x};}
多项式运算
略。
点值表达法
选出 ((2n-1)) 个互不相同的横坐标 (x_i) ,代入 (f) 与 (g) 中,得到很多个 (fy_i,gy_i),而 ((x_i,fy_i)) 就是 (f) 的点值表达式, ((x_i,gy_i)) 就是 (g) 的点值表达式。神奇的事实是, ((x_i,fy_i imes gy_i)) 就是 (h) 的点值表达式!
!所以 FFT 的思想就是,将 (f) 和 (g) 转换成点值表达法,然后相乘得到 (h),最后再化为系数表示法(普通多项式的表示)。
其中,转化为点值表达法的步骤叫做 DFT,化回来的步骤叫做 IDFT。
单位根
有一个神奇的东西叫做单位根(复数),满足 (w^n=1) 的 (w) 被称作 (n) 次单位根。(事实上应该是 (omega),但是写起来太麻烦了就用 (w) 了)
经过推导,如果将所有的单位根排列,第 (k) 个 (n) 次单位根 (Large w_k=e^{ifrac{2kpi}{n}})
设 (n) 是正偶数,且 (m) 是 (n) 的一半,那么有 ((w_n^k)^2=w_m^k) 以及 (w_n^{m+k}=-w-n^k)。这两个等式就是法术,下文会用到。
算法
DFT(转成点值表示)
首先为了方便,我们把多项式变为 (n=(2^k-1)) 次多项式(不足的补系数 (0)),而且它大于原本的 (f) 和 (g) 的项数之和。这样原来的 ((2n-1)) 就不大于的 (n) ,方便统计。
然后我们想知道点值表达,所以我们需要代入一些 (x),并求出 (f(x)) 的值。
我们选择把魔幻的单位根, (w_n^0,w_n^1...w_n^{n-1}) 代入。
好,我们怎么求 (f(x))?先变变形式:
其中 (f,f_0,f_1) 并不是一样的,注意,(f_0) 的系数依次为 (a_0,a_2,a_4...) ,(f_1) 为 (a_1,a_3,a_5...)
运用着前文提到的法术,仍然设 (m=frac{n}{2}),对于 (k<m) 有
那么对于 (kge m) 呢?我们可以说它是 ((k+m)) 且 (k<m) 。继续用法术:
哇塞,一个是 (f_0+f_1),另一个是 (f_0-f_1)。
于是我们只要求出 (w_n^0sim w_n^m) 就可以知道剩下的了。而 (f_0,f_1) 可以递归求。
IDFT(化回来)
不会。但是代码和 DFT 是基本一样的。
实现
首先是直观但是常数大的递归版:
void FFT(mle *c,int len,bool sys){
//part 1
if(len==1)return;
mle zuo[(len>>1)+1],you[(len>>1)+1];
for(int i=0;i<=len;i=i+2)zuo[i>>1]=c[i],you[i>>1]=c[i+1];
FFT(zuo,len>>1,sys),FFT(you,len>>1,sys);
//part 2
lod tnp=2.0*pie/len;int wal=len>>1;
mle ori=(mle){cos(tnp),(sys?1:-1)*sin(tnp)},z=(mle){1,0};
//part 3
rep(i,0,wal-1){
c[i]=zuo[i]+z*you[i];
c[i+wal]=zuo[i]-z*you[i];
z=z*ori;
}
}
PART:
-
把系数分为两个部分,然后依次递归。
-
计算。其中 (w_n^0=1,w_n^k=w_n^{k-1}*ori)。
其中 DFT 时 (ori=cos{frac{2pi}{n}+sinfrac{2pi}{n}}),IDFT 时 (ori=cos{frac{2pi}{n}-sinfrac{2pi}{n}})。以及 (wal) 就是 (frac{n}{2})。
-
做前文的事情。运用 (w_n^k=w_n^{k-1}*ori)。
注意,DFT 的时候把原本装系数的数组变成了现在装 (f(w)) 的数组。
于是你就轻松有了 66分。毫无疑问,常数太大了!
优化
对于第 (x) 个系数 (a_x) ,它的路径是怎么样的?
0 1 2 3 4 5 6 7
0 2 4 6,1 3 5 7
0 4,2 6,1 5,3 7
0,4,2,6,1,5,3,7
你会发现,
如果把 0,4,2,6,1,5,3,7
,
每一个数都转为二进制 000, 100, 010, 110, 001, 101, 011, 111
,
再每一个二进制反过来 000, 001, 010, 011, 100, 101, 110, 111
,
最后化为十进制 0,1,2,3,4,5,6,7
。哦豁!
所以可以快速求得递归的底层是怎么样的,然后我们模拟递归,枚举长度(1,2,4,8……),然后把一段长度的合并。
但是又怎么样求二进制反转呢?
(color{red}Huge 待填!)
顺便提一句小优化,因为复数乘法常数大,所以一般在 (f=f_0pm f_1)是这样写:
rep(i,0,wal-1){
mle mul=z*you[i];
c[i]=zuo[i]+mul;
c[i+wal]=zuo[i]-mul;
z=z*ori;
}
最后代码:
也是待填
FFT代码
#include<bits/stdc++.h>
#define rep(i,x,y) for(int i=x;i<=y;++i)
#define lod double
using namespace std;
const int n7=3012345;
const lod pie=acos(-1);
int n,m,rv[n7];
struct mle{lod x,y;}a[n7],b[n7];
mle operator + (mle p,mle q){return (mle){p.x+q.x,p.y+q.y};}
mle operator - (mle p,mle q){return (mle){p.x-q.x,p.y-q.y};}
mle operator * (mle p,mle q){return (mle){p.x*q.x-p.y*q.y,p.x*q.y+p.y*q.x};}
int rd(){
int shu=0;char ch=getchar();
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))shu=(shu<<1)+(shu<<3)+ch-'0',ch=getchar();
return shu;
}
void FFT(mle *c,bool sys){
rep(i,0,n-1)if(i<rv[i])swap(c[i],c[ rv[i] ]);
for(int len=2;len<=n;len<<=1){
mle ori=(mle){cos(2*pie/len),(sys?1:-1)*sin(2*pie/len)};
int le=(len>>1);
for(int i=0;i<n;i+=len){
mle z=(mle){1,0};
rep(j,i,i+le-1){
mle tmp=z*c[le+j];
c[le+j]=c[j]-tmp;
c[j]=c[j]+tmp;
z=z*ori;
}
}
}
}
int main(){
n=rd(),m=rd();
rep(i,0,n)a[i].x=rd();
rep(i,0,m)b[i].x=rd();
m=m+n,n=1;
while(n<=m)n=n<<1;
rep(i,0,n-1){
rv[i]=(rv[i>>1]>>1);
if(i&1)rv[i]=rv[i]|(n>>1);
}
FFT(a,1),FFT(b,1);
rep(i,0,n)a[i]=a[i]*b[i];
FFT(a,0);
rep(i,0,m)printf("%d ",(int)(a[i].x/n+0.5));
return 0;
}
NTT代码
#include<bits/stdc++.h>
#define rep(i,x,y) for(int i=x;i<=y;++i)
#define lon long long
using namespace std;
const int n7=3012345;const lon mo=998244353;
int n,m,rv[n7];lon a[n7],b[n7];
int rd(){
int shu=0;char ch=getchar();
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))shu=(shu<<1)+(shu<<3)+ch-'0',ch=getchar();
return shu;
}
lon Dpow(lon p,lon q){
lon tot=1;
while(q){
if(q&1)tot=tot*p%mo;
p=p*p%mo,q=q>>1;
}
return tot;
}
void NTT(lon *c,bool sys){
rep(i,0,n-1)if(i<rv[i])swap(c[i],c[ rv[i] ]);
for(int len=2;len<=n;len<<=1){
lon ori=Dpow(sys?3:332748118,(mo-1)/len);
int le=(len>>1);
for(int i=0;i<n;i+=len){
lon z=1;
rep(j,i,i+le-1){
lon tmp=z*c[le+j]%mo;
c[le+j]=(c[j]-tmp+mo)%mo;
c[j]=(c[j]+tmp)%mo;
z=z*ori%mo;
}
}
}
}
int main(){
n=rd(),m=rd();
rep(i,0,n)a[i]=rd();
rep(i,0,m)b[i]=rd();
m=m+n,n=1;
while(n<=m)n=n<<1;
rep(i,0,n-1){
rv[i]=(rv[i>>1]>>1);
if(i&1)rv[i]=rv[i]|(n>>1);
}
NTT(a,1),NTT(b,1);
rep(i,0,n)a[i]=a[i]*b[i]%mo;
NTT(a,0);
lon inv=Dpow(n,mo-2);
rep(i,0,n)a[i]=a[i]*inv%mo;
rep(i,0,m)printf("%lld ",a[i]);
return 0;
}