某知名选手:出多项式题的人就像在贩毒,做多项式的人就像在嗑药。
一直就想写关于嗑药的内容了,但是由于嗑药所需要的时间很久,而且我没有大块的时间来写一篇真正入门的东西,所以一直咕咕咕。
直到现在,为了自我复习整理一遍思路,写了一篇真正入门的FFT教程。
话不多说,直接进入正题。
一.所需前置芝士
1.多项式是啥?
形如$f(x)=ax^4+bx^3+cx^2+dx+e$这样的式子,即:$sum_{i=0}^{n}a_ix^i$叫做多项式。
定义:最高次项为n的多项式叫做n次多项式。
推论:由于常数项的存在,n次多项式最多有(n+1)项。
2.复数是啥?
正常高中所有的知识都是在实数的基础下尽心运算,但在算法中仅靠实数这些数远远不够。无法达到算法的目的。因此引入数域最大的复数。(注意,复数就是所说的虚数)。
复数的例子:实数中无法表示$sqrt{-1}$,而这个数在复数中是真实存在的。记作$i$;即:$i^2=-1$
设a,b是实数,那么形如$a+ib$的数叫复数。其中$i$被称为虚数单位,复数域是目前已知最大的域。
在复平面中,x代表实数,y轴(除原点外的点)代表虚数,从原点(0,0)到(a,b)的向量表示复数$a+bi$
模长:从原点(0,0)到点(a,b)的距离,即$sqrt{a^2+b^2}$。
幅角:假设以逆时针为正方向,从x轴正半轴到已知向量的转角的有向角叫做幅角。
复数的运算:
1.加法:实数部相加,虚数部相加。即:$(a+ib)+(c+id)=(a+c)+i(b+d)$
2.减法:实数部相减,虚数部相减。即:$(a+ib)-(c+id)=(a+c)-i(b+d)$
3.乘法:$(a+ib)*(c+id)=ac+iad+ibc+i^2bd=ac-bd+i(ad+bc)$
单位根:在复平面上,以原点为圆心,1为半径作圆,所得的圆叫单位圆。以圆点为起点,圆的n等分点为终点,做n个向量,设幅角为正且最小的向量对应的复数为$omega_n$,称为n次单位根。
注意,上文单位根中所说的n等分中的n必须是2的正整数次幂。
根据复数乘法的运算法则,其余n-1个复数为$omega_n^2,omega_n^3,ldots,omega_n^n$
单位根的性质:
1.$omega_n^0=omega_n^n=1$ 意义:在x实数正半轴上,长度为1。
2.根据复数的定义,我们将复数的实部和虚部作为向量进行运算,也就是说:$omega_n^k=cos heta+isin heta $,而$ heta=2*kfrac{2pi}{n}$,所以$omega_n^k=cos(kfrac{2pi}{n})+isin(kfrac{2pi}{n})$
3.若z的n次方为1,那么就叫z为n次单位根。
4.$omega_n^k*omega_n^k=(cos(kfrac{2pi}{n})+isin(kfrac{2pi}{n}))*(cos(kfrac{2pi}{n})+isin(kfrac{2pi}{n}))$
$=cos^2(kfrac{2pi}{n})-sin^2(kfrac{2pi}{n})+i(2cos(kfrac{2pi}{n})sin(kfrac{2pi}{n}))$
$=frac{1+cos (2kfrac{2pi}{n})}{2}-(frac{1-cos (2kfrac{2pi}{n})}{2})+isin(2kfrac{2pi}{n})$ (根据高中所学的三角降幂公式)
$=cos (2kfrac{2pi}{n})+isin(2kfrac{2pi}{n})$
$=omega_n^{2k}$
5.消去引理:$omega_n^k=omega_{2n}^{2k}$
证明:$omega_{2n}^{2k}=cos(2kfrac{2pi}{2n})+isin(2kfrac{2pi}{2n})=cos(kfrac{2pi}{n})+isin(kfrac{2pi}{n})=omega_n^k$
推论:$omega_n^k=omega_{dn}{dk}$
6.折半引理:$omega_{n}^{k+frac{n}{2}}=-omega_n^k$
证明:$omega_n^{frac{n}{2}}=cos(frac{n}{2}*frac{2pi}{n})+isin(frac{n}{2}*frac{2pi}{n})$
$=cos pi+isin pi=-1$
所以$omega_{n}^{k+frac{n}{2}}=omega_n^k *omega_n^{frac{n}{2}}=-omega_n^k$
二.快速傅里叶变换走起
1.点值表示法
我们知道,两个多项式相乘也就是卷积,正常来算肯定是$O(n^2)$的,而且乍眼一看似乎没有什么优化方法。但是,就是有人研究出了不损失正确性的$O(nlogn)$的算法,这个研究过程我们稍微提及一下。
首先,我们平常接触的形如$f(x)=ax^4+bx^3+cx^2+dx+e$的式子表示唯一一个多项式叫做系数表示法。其次,n个点确定唯一一个n-1次多项式,那么我们用n个点也可以唯一表示一个n-1次多项式,这种表示方法叫做点值表示法。
对于两个用点值表示法表示的多项式如:$A(x)=((x_0,y_a0),(x_1,y_a1),......,(x_n,y_an))$、$B(x)=((x_0,y_b0),(x_1,y_b1),......,(x_n,y_bn))$相乘,那么相乘结果的多项式的点值表示法就是$C(x)=((x_0,y_a0*y_b0),(x_1,y_a1*y_b1),......,(x_n,y_an*y_bn))$,而这样求得乘积的复杂度是$O(n)$的,原理利用一次函数或者二次函数自己体会就能弄明白。(注意,为了能确定乘积唯一一个多项式,我们定义n为乘积所得多项式$C()$的次数,这样可以保证得到n+1个点确定唯一一个n次多项式。也就是说,多项式$A()$和$B()$所取的点的个数要相等且等于$(两个多项式的次数和-2)$)。
那么问题转换为将多项式系数表示法转化成点值表示法。
朴素系数转点值的算法叫DFT(离散傅里叶变换),优化后为FFT(快速傅里叶变换),点值转系数的算法叫IDFT(离散傅里叶逆变换),优化后为IFFT(快速傅里叶逆变换)。
对于DFT,想必初中生都会,也就是O(n^2)的暴力取值然后带入多项式计算。至于IDFT?高斯消元也是可以做的。
那么FFT呢?接下来的部分便主要介绍FFT。
2.FFT快速傅里叶变换
还记得复数吗?不记得了?赶快去上面翻一翻复数的那些性质,我们在下文会经常的用到。
假设存在一个多项式:$A(x)=a_0x^0+a_1x^1+a_2x^2+......+a_{n-2}x^{n-2}+a_{n-1}x^{n-1}$
然后我们按照下标奇偶性分成两部分:$A(x)=(a_0x^0+a_2x^2+a_4x^4+......)+(a_1x^1+a_3x^3+a_5x^5+......)$
我们定义两个多项式,$A_1(x)=(a_0x^0+a_2x^1+a_4x^2+......)$,$A_2(x)=(a_1x^0+a_3x^1+a_5x^2+......)$ (注意:x的次数和$A()$中x的次数不一样)
我们根据初中知识可以知道$A()$和$A_1()$、$A_2()$的关系:$A(x)=A_1(x^2)+xA_2(x^2)$
我们将单位根$omega_n^k(k<frac{n}{2})$代入上面的式子,那么$A(omega_n^k)=A_1((omega_n^k)^2)+omega_n^kA_2((omega_n^k)^2)$
$=A_1(omega_n^{2k})+omega_n^kA_2(omega_n^{2k})$
然后将$omega_n^{k+frac{n}{2}}(k<frac{n}{2})$再一次代入上面的式子,我们可以推导:$A(omega_n^{k+frac{n}{2}})=A_1((omega_n^{k+frac{n}{2}})^2)+omega_n^{k+frac{n}{2}}A_2((omega_n^{k+frac{n}{2}})^2)$
$=A(omega_n^{k+frac{n}{2}})=A_1(omega_n^{2k+n})+omega_n^{k+frac{n}{2}}A_2(omega_n^{2k+n})$
$=A(omega_n^{k+frac{n}{2}})=A_1(omega_n^{2k}*omega_n^n)+omega_n^{k+frac{n}{2}}A_2(omega_n^{2k}*omega_n^n)$
$=A(omega_n^{k+frac{n}{2}})=A_1(omega_n^{2k})+omega_n^{k+frac{n}{2}}A_2(omega_n^{2k})$
$A(omega_n^{k+frac{n}{2}})=A_1(omega_n^{2k})-omega_n^{k}A_2(omega_n^{2k})$
发现了什么?以上两种取值带入该多项式后只有常数项不同,而这两种取值范围合起来正好是全域且两个域的大小相同。
也就是说,我们在计算n个不同$A()$的点值表达的时候,我们可以把问题缩小一半,而且这个问题可以递归去做(因为要求的是另外两个多项式)。于是得到了接近于$O(nlogn)$的获取n个两两不同的点值的算法。
等到递归到多项式仅仅有一个常数项时,我们无论给这个多项式什么参数,返回值都是这个常数。因此多项式项数为1的时候(次数为0)结束递归,返回这个常数。
具体写法参考一会出现的的代码。
FFT完结撒花~(逗你的,但FFT真的完了)
3.IFFT(快速傅里叶逆变换)
我们在做题的时候,很少会有人使用点值表示法来表示一个多项式(你可以试试在做二次函数抛物线的时候抛给老师一个点值表示法的多项式)。所以我们还需要一个告诉的算法,把点值表示法转换成系数表示法。这就是IFFT要做的。
我们假设:$(y_0,y_1,y_2,y_3,......,y_n)$是一个多项式$F()$在$(b_0,b_1,b_2,b_3,......,b_n)$处用FFT求出来的点值表示。其中多项式$F()$就是上文提到的多项式$A()$和多项式$B()$的乘积,而$b_i$其实就是多项式$F()$的n+1个系数。
有些本文上面的内容没有消化好的会说:$(y_0,y_1,y_2,y_3,......,y_n)$就是该多项式$F()$的系数。
这就大错特错了,希望这么想的人一定要好好阅读一下FFT的内容后再来学习IFFT。
但是----我们不妨就真的把$(y_0,y_1,y_2,y_3,......,y_n)$当作一个n次多项式$G()$的系数。 (注意,这里的多项式$G()$并不是多项式$A()$和多项式$B()$的乘积,而是新设的一个多项式)
我们假设,$G(k)=sum_{i=0}^{n}y_i(omega_n^{-k})^i$
$=sum_{i=0}^{n}(sum_{j=0}^{n}b_j(omega_n^i)^j)(omega_n^{-k})^i$
$=sum_{i=0}^{n}sum_{j=0}^{n}b_i(omega_n^j)^i(omega_n^{-k})^i$
$=sum_{j=0}^{n}b_jsum_{i=0}^{n}(omega_n^jomega_n^{-k})^i$
$=sum_{j=0}^{n}b_jsum_{i=0}^{n}(omega_n^{j-k})^i$
发现什么没有?没发现?那么请接着推式子。发现了?那么就请跳过这一段。
我们设$S(x)=sum_{i=0}^{n}x^i$-------------------①式
将方程两侧同时乘x,那么方程变为:$xS(x)=sum_{i=0}^{n}x^{i+1}$-------------------②式
用②式减①式,得到:$(x-1)S(x)=x^{n+1}-1$
也就是说,我们得到:$S(x)=frac{x^{n+1}-1}{x-1}$
我们将$omega_n^k$代入上面的式子,式子变成了$S(omega_n^k)=frac{(omega_n^k)^{n+1}-1}{omega_n^k-1}$
然后我们发现了一个事情,那就是这个式子如果n+1变成n就会变得更加美妙又间接。
根据点值表示法的定义,我们清楚:如果(n+1+k)个点都在某个n次多项式上,那么这(n+1+k)个点也可以确定一个n次多项式。($k>0$)
因为n是2的正整数幂,所以我们在选取n的值的时候要满足,n要严格大于所求乘积的多项式的次数+1,这样就可以保证选取n-1个在该多项式上的点就可以唯一确定该多项式。
这么做有什么用呢?我们可以利用数学归纳法的思想把上文中所有提到的n都-1,也就是说,n变成n-1,而n+1可以变成n。
再来观察这个式子:$S(omega_n^k)=frac{(omega_n^k)^{n}-1}{omega_n^k-1}$
$=frac{(omega_n^n)^k-1}{omega_n^k-1}$
$=frac{1-1}{omega_n^k-1}$
因为$omega_n^k$中的k原来取遍0~n,那么现在k只可以取遍0~n-1。
所以当k不等于0的时候$S(omega_n^k)=frac{1-1}{omega_n^k-1}=0$
那么当k等于0的时候呢?显然,$S(omega_n^k)=n$
现在再来看这个式子:$G(k)=sum_{j=0}^{n}a_jsum_{i=0}^{n}(omega_n^{j-k})^i$
当$j=k$的时候,后半部分的值为n,而当$j!=k$的时候,后半部分的值为0.
因此可以得到:$G(k)=nb_k$
所以:$b_k=frac{G(k)}{n}$
至于$G(k)$怎么求呢?别忘了,我们把点值$(y_0,y_1,y_2,y_3,......,y_n)$当作了一个n次多项式$G()$的系数。而$G_k$其实就是多项式$G(k)$的点值。所以求一个已知系数的n-1次多项式的n点值我们用什么?FFT!FFT!FFT!
因此我们再用一边快速傅里叶变换把多项式$A()$、$B()$快速傅里叶变换后相乘所得多项式$F()$的点值当作另一个多项式$G()$的系数求出$G()$的点值表达式。然后$G()$的n个点值各自除n就是$F()$的n个系数系数。
至此,IFFT完结撒花~
这些推导一定要理解不要死背,要不然只会做FFT的板子啊~。
四.实践中创新
1.由于c++自带的复数库complex太慢,因此我们自己定义复数类:
struct complex { double x,y; complex (double xx=0,double yy=0){x=xx,y=yy;} }a[200010],b[200010]; complex operator + (complex a,complex b){ return complex(a.x+b.x , a.y+b.y);} complex operator - (complex a,complex b){ return complex(a.x-b.x , a.y-b.y);} complex operator * (complex a,complex b){ return complex(a.x*b.x-a.y*b.y , a.x*b.y+a.y*b.x);}
2.利用FFT运算两个多项式的卷积:
#include <iostream> #include <cstdio> #include <cmath> #define inc(i,a,b) for(register int i=a;i<=b;i++) using namespace std; const double Pi=acos(-1.0); struct complex { double x,y; complex (double xx=0,double yy=0){x=xx,y=yy;} }a[200010],b[200010]; complex operator + (complex a,complex b){ return complex(a.x+b.x , a.y+b.y);} complex operator - (complex a,complex b){ return complex(a.x-b.x , a.y-b.y);} complex operator * (complex a,complex b){ return complex(a.x*b.x-a.y*b.y , a.x*b.y+a.y*b.x);} void fft(int nowlimit,complex *now,int type){ if(nowlimit==1) return; complex a1[nowlimit>>1],a2[nowlimit>>1]; for(int i=0;i<=nowlimit;i+=2){ a1[i>>1]=now[i]; a2[i>>1]=now[i+1]; } fft(nowlimit>>1,a1,type); fft(nowlimit>>1,a2,type); complex wn=complex(cos(2.0*Pi/nowlimit),type*sin(2.0*Pi/nowlimit)),w=complex(1,0); for(int i=0;i<(nowlimit>>1);i++,w=w*wn){ now[i]=a1[i]+w*a2[i]; now[i+(nowlimit>>1)]=a1[i]-w*a2[i]; } } int main(){ int n,m; cin>>n>>m; inc(i,0,n) cin>>a[i].x; inc(i,0,m) cin>>b[i].x; int limit=1; while(n+m>=limit) limit<<=1; fft(limit,a,1); fft(limit,b,1); inc(i,0,limit) a[i]=a[i]*b[i]; fft(limit,a,-1); inc(i,0,n+m) printf("%d ",(int)(a[i].x/limit+0.5)); }
我们来观察一下上面的代码,不难看出,这是一个递归版FFT,因此这个FFT慢到家了,连模板都无法AC,甚至比n^2跑的还慢,但是其中有许多技巧我需要简单说一说。
2.1.我们来看这段代码:
complex wn=complex(cos(2.0*Pi/nowlimit),type*sin(2.0*Pi/nowlimit)),w=complex(1,0);
其中注意,wn表示的是在当前区间长度nowlimit下的nowlimit次单位根。也就是说,$wn^{nowlimit}=1$,这意味着把复平面均分成nowlimit份。而nowlimit一定是2的正整数幂。而w表示的就是$omega_{nowlimit}^0$。
2.2.我们再来看这段代码:
for(int i=0;i<(nowlimit>>1);i++,w=w*wn){ now[i]=a1[i]+w*a2[i]; now[i+(nowlimit>>1)]=a1[i]-w*a2[i]; }
请注意,由于我们的数组now是一个指针,所以递归时改变该层的now数组其实就是改变递归上一层时的a1数组或者a2数组。(因为递归时我们调用了fft(nowlimit,a1,type)和fft(nowlimit,a2,type));
而now数组在第一次递归时指向的是原数组本身(系数数组),所以在fft后,原系数数组就变成了点值表达式数组。
有些人可能会问:我们推出来的式子不是$A(omega_n^k)=A_1(omega_n^{2k})+omega_n^kA_2(omega_n^{2k})$吗?怎么到代码里就变成$now[i]=a_1[i]+w*a_2[i]$了呢?说好的平方呢?
其实代码没有错,因为我们wn表示的是在当前递归层区间长度为nowlimit下的nowlimit次单位根,而在上一递归层中的nowlimit是这一层的2倍。因此上一层的wn正好是这一层wn的平方根。也就是说:上一层wn的平方正好是这一层的wn。再换句话说,就是上一层选取的单位根的个数正好是这一层选取单位根个数的两倍,而这等价于上一层所得到的点值的个数是这一层所得点值个数的两倍。
而至于$now[i]=a_1[i]+w*a_2[i]$中的i是什么呢?其实i只是一个寻址符。我们把复平面均分成(1<<n)份,从x轴正半轴开始,以逆时针为正方向。我们在每层递归时按照正方向的顺序依次选取单位根,其中第i个选取的单位根代入多项式所得的答案(重点!)就是$a_1[i]$。$a_2[i]$同理,单位根所选取的数都是一样的,但答案与$a_1[i]$不同。因为虽然单位根的值相同但这次是代入多项式$a2()$时所得的答案。
2.3.在主函数中,我们进行了一次fft(limit,a,-1)。这是在干什么呢?
之前说过,IFFT中把多项式$A()$、$B()$快速傅里叶变换后相乘所得多项式$F()$的点值当作另一个多项式$G()$的系数求出$G()$的点值表达式。然后$G()$的n个点值各自除n就是$F()$的n个系数系数。而$G(k)=sum_{i=0}^{n}y_i(omega_n^{-k})^i$,次数是负数(感性理解一下),因此我们这里的type代入-1。
updata:对于上面一行中次数是负数这一点,我们可以想一想三角函数。因为在x轴上方的角的sin值都为正,x轴下方的角的sin值都为负,y轴右侧的cos值都为正,y轴左侧的cos值都为负。而且在单位根次数取负数时就相当于取以x轴为对称轴的那个单位根。所以sin值变为原来的相反数,cos值不变。
3.关于常数优化
常数什么的才不是关键呢?(啊呸!这卡的比$n^2$还慢而且还爆栈)
常数优化什么的,如果是写的如此糟糕的fft的话就一定能做到的吧。
3.1没错,我们来看一个听起来很nb的操作:蝴蝶变换:
for(int i=0;i<(limit>>1);i++,w=w*Wn) { complex t=w*a2[i]; a[i]=a1[i]+t, a[i+(limit>>1)]=a1[i]-t; }
我们发现了什么?没错,你没看错。仅仅是把$w*a_2[i]$从算两次变成了算一次,这样就优化掉了一个大大的复数乘法啦~
3.2还有什么优化呢?比如说递归变成循环模拟?
你没想错,我们手动用循环模拟递归过程,这样在优化常数的时候同时解决了爆栈这个问题。
但如果我们想要不递归,就要提前知道递归最底层时每一层的系数是多少,而要知道这个,似乎只有递归一条路。是否有其他办法呢?
我们打表观察:
发现了啥?没错,原序列与后序列在二进制表示下的差别只是翻转原序列过得到的。
这样就可以完全避免递归。
我们先上代码,解释在代码之后。
#include <iostream> #include <cstdio> #include <cmath> #define inc(i,a,b) for(register int i=a;i<=b;i++) using namespace std; const double Pi=acos(-1.0); struct complex { double x,y; complex (double xx=0,double yy=0){x=xx,y=yy;} }a[400010],b[400010]; complex operator + (complex a,complex b){ return complex(a.x+b.x , a.y+b.y);} complex operator - (complex a,complex b){ return complex(a.x-b.x , a.y-b.y);} complex operator * (complex a,complex b){ return complex(a.x*b.x-a.y*b.y , a.x*b.y+a.y*b.x);} int limit=1,num=0; int rev[400010]; void fft(complex *now,int type){ inc(i,0,limit-1) if(i<rev[i]) swap(now[i],now[rev[i]]); for(int mid=1;mid<limit;mid<<=1){ //枚举待合并区间的中点 complex wn=complex(cos(Pi/mid),type*sin(Pi/mid)); for(int r=mid<<1,j=0;j<limit;j+=r){ //r是当前区间的大小,j是当前区间的左端点 complex w=complex(1,0); for(int k=0;k<mid;k++,w=w*wn){ //k是当前在区间的什么位置,w是当前的单位根。 complex x=now[j+k],y=now[j+k+mid]*w; now[j+k]=x+y; now[j+k+mid]=x-y; } } } } int main(){ int n,m; cin>>n>>m; inc(i,0,n) cin>>a[i].x; inc(i,0,m) cin>>b[i].x; while(n+m>=limit) limit<<=1,num++; inc(i,0,limit-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(num-1)); //对于这里的解释,最好的理解方法就是对着刚才的那张图自己模拟递推过程。模拟一遍后就知道远原理了。 fft(a,1); fft(b,1); inc(i,0,limit) a[i]=a[i]*b[i]; fft(a,-1); inc(i,0,n+m) printf("%d ",(int)(a[i].x/limit+0.5)); }
3.2.1 关于rev[i]
rev[i]表示原序列第i个元素在后序列的位置是rev[i]。而rev[i]的计算自己模拟一遍就能明白
3.2.2 关于得到后序列
我们在得到后序列的时候加了if(i<rev[i]),这是为了交换就交换一遍。如果不写这段代码,那么交换后的序列和原序列一模一样。
3.2.3. 关于$complex wn=complex(cos(Pi/mid),type*sin(Pi/mid));$
我们会发现,为什么Pi不用*2了呢?这是因为mid指的是区间的一半,而我们要除的是整个区间的长度,所以分母的2和分子的2抵消了。
3.2.4. 关于mid
mid指的是当前区间的中点,但要注意,mid要向上取整。也就是说,mid同样是该区间右半部分的第一个元素。