FFT详解
一些前置知识
任意角
详见高中必修4-1.1.1
定义:按逆时针方向旋转形成的角叫做正角;按顺时针方向旋转形成的角叫做负角
弧度制
详见高中必修4-1.1.2
定义:把长度等于半径长的弧所对的圆心角叫1弧度的角,用符号 rad 表示(一般省略不写),读作弧度
常见弧度数:$360^circ = 2pi $ (180^circ = pi)
多项式
有(n)次多项式(F(x)=a_0+a_1x+a_2x^2+cdots+a_nx^n)
显然,它有(n+1)项
那么该多项式的系数表示就是((a_0,a_1,a_2,cdots,a_n))
按照正常做法,在系数表示下,两个多项式相乘需要(O(n^2))的时间复杂度
用函数、方程的思想来看(F(x)),至少确定函数上((n+1))个点就可以推出唯一的(n)次函数(F(x))
这(n+1)个点的坐标集合({(x_1,y_1),(x_2,y_2),cdots,(x_{n+1},y_{n+1})})称为该多项式的点值表示
在点值表示下,两个多项式相乘只需将对应y坐标相乘
又因为两个(n)项的多项式相乘所得的多项式有(2n+1)项
所以先将(A(x))和(B(x))的点值表达扩展到(2n+1)个点(上面说到的(n+1)个点是至少,实际上还可以再多)
再进行点值表达下的多项式相乘
如下:
(A(x)={(x_1,y_1),(x_2,y_2),cdots,(x_{n+1},y_{n+1})})
(B(x)={(x_1,y_1'),(x_2,y_2'),cdots,(x_{n+1},y_{n+1}')})
扩展后:
(A(x)={(x_1,y_1),(x_2,y_2),cdots,(x_{2n+1},y_{2n+1})})
(B(x)=(x_1,y'_1),(x_2,y'_2),cdots(x_{2n+1},y_{2n+1}))
(C(x)=A(x)B(x)={(x_1,y_1y_1'),(x_2,y_2y_2'),cdots,(x_{2n+1},y_{2n+1}y_{2n+1}')})
可见,在点值表示下,两个多项式相乘只需要(O(n))的时间复杂度
向量
详见高中必修4-2.1
有向线段:带有方向的线段。以(A)为起点,(B)为终点的有向线段记作(vec{AB}),起点写在终点前面;它的长度记作(|vec{AB}|)
三个要素:起点、方向、长度。知道了有向线段的三个要素,它的终点就唯一确定。
向量一般用有向线段表示
复数
定义:记(i=sqrt{-1}),把形如(a+bi)((a)、(b)均为实数)的数称为复数,其中(a)是实部,(b)是虚部
复平面:复数的平面中,(x)轴称为实轴,表示复数的实部;(y)轴称为虚轴,表示复数的虚部。为了方便理解,一般用从原点到点((a,b))的向量表示复数(a+bi)
模长:在复平面上,原点到点((a,b))的距离(sqrt{a^2+b^2})为复数(a+bi)的模长
辐角:在复平面上,实轴与表示复数(a+bi)的向量所形成的正角( heta)称为复数(a+bi)的辐角
运算:复数运算与实数运算相似,如下:
注意!!!(敲黑板*2)有个结论:复数乘法,辐角相加,模长相乘
欧拉公式
公式长这个样子:(e^{ix}=cos (x) + i sin(x))。其中(x)是一个实数
这篇博客中有详细的证明
单位根
数学上,n次单位根是n次幂为1的复数。它们位于复平面的单位圆上,构成正n边形的顶点,其中一个顶点是1。(From Baidu)
单位圆:在复平面上,以原点为圆心,单位长度为半径所作出的圆
(n)次单位根:在复平面上,以原点为起点,单位圆的(n)等分点为终点的向量所表示的(n)个复数中,辐角最小的向量所对应的复数(omega_n),称为(n)次单位根。将与(x)轴正半轴重合的那个向量记为(omega_n^0)(那个上标实际上是乘方),沿逆时针方向将剩余的向量顺序标记为(omega_n^1,omega_n^2,cdots,omega_{n-1})。所以(omega_n^n=omega_n^0=1)。
将单位根和欧拉公式摆在一起,将上面提到的欧拉公式中的(x)换成(2pi)(表示(n)次单位根(omega_n^1)的向量与实轴所形成的夹角所对应的单位圆上的弧的长度),可以得到:$$e^{2pi i}=cos(2pi)+i sin(2pi)=1=omega_n^n$$
所以:(e^{frac{2pi i}{n}}=cos(frac{2pi}{n})+isin(frac{2pi}{n})=omega_n)
我们称此时的单位根(omega_n)为主次单位根
那么其它的单位根就有:(omega_n^k=cos(frac{2pi k}{n})+isin(frac{2pi k}{n})(0leq kleq n))
消去引理:(omega_{dn}^{dk}=omega_n^k) ((n)、(d)、(k)均为整数且(ngeq0,d>0,kgeq0))。证明如下:$$omega_{dn}{dk}=(e{frac{2pi i}{dn}}){dk}=(e{frac{2pi i}{n}})k=omegak_n$$
折半引理:如果(n>0)且(n)为偶数,那么(n)个(n)次单位复数根的平方的集合就是(frac{n}{2})g个(frac n 2)次单位复数根的集合。证明如下:
求和引理:(sum^{n-1}_{j=0}(omega^k_n)^j=0),证明如下:
矩阵
定义:矩阵是一个按照长方阵列排列的实数或负数集合,如下大小为(n imes m)的矩阵(A),其中第(i)行第(j)列的数记作(a_{ij})
矩阵乘法:矩阵运算中比较常用的是矩阵乘法。矩阵乘法只有在第一个矩阵的行数和第二个矩阵的列数相同时才有意义,因此矩阵乘法不满足交换律。矩阵乘法法则为:有(p imes n)的矩阵(A)和(m imes p)的矩阵(B),同时有矩阵(C=A imes B),那么(C_{ij}=sum_{k=1}^{p}a_{ki} imes b_{jk})
范德蒙德矩阵:一个(n imes n)阶的范德蒙德矩阵(V)由包含(n)个数的数列(a_0,a_1,cdots,a_{n-1})决定,其中(v_{ij}=a_j^i(0leq i,j<n)),即:
范德蒙德矩阵很特殊,它的逆矩阵是这样的:
如何加速多项式乘法
一般的多项式乘法题目都是给出两个多项式的系数表示,求这两个多项式相乘后所得的多项式的系数表示
对于这样的多项式乘法,最简单的方法就是两个for循环,也是系数表示下的多项式乘法,时间复杂度(O(n^2))。代码如下,非常简洁直观:
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
ans[i+j-1]=a[i]*b[j];
}
}
太慢了!!!
上面还提到,点值表示下的多项式乘法只需要(O(n))
那如果我们可以很快地将一个多项式的系数表示转为点值表示,然后在点值表示下进行多项式乘法,最后很快地将相乘得到的多项式点值表示转为系数表示,我们就可以在多项式乘法中有一个优秀的时间复杂度了
这里介绍两种变换:
DFT:离散傅里叶变换,全称Discrete Fourier Transform,是指将多项式从系数表示转为点值表示的过程
IDFT:离散傅里叶逆变换,全称Inverse Discrete Fourier Transform,是指将多项式从点值表示转为系数表示的过程
DFT&FFT
DFT
将系数表示转为点值表示最简单的做法就是对多项式(A(x)=a_0+a_1x+a_2x^2+cdots+a_nx^n)中的(x)取(n+1)个不同的值,逐个代入多项式,求出对应的(y)
显然,将单个(x)代入求(y)需要(O(n))的时间,那么将(n+1)个(x)代入求对应(n)个(y)就需要(O(n^2))的时间复杂度
还是太慢了。。。
这个时候就用到FFT加速DFT
FFT
FFT:快速离散傅里叶变换,全称Fast Fourier Transformation,就是加速后的DFT
在FFT中,采取分治的策略优化时间
为了方便分治,一般把多项式的项数补成2的正整数次幂,即将原本没有的更高次项的系数取0
下文默认(n)为2的正整数次幂
有(n)项多项式(A(x)=a_0+a_1x^1+a_2x^2+cdots+a_{n-1}x^{n-1})
对这个多项式进行奇偶分治,得到以下两个多项式:
(F(x)=a_1x+a_3x^3+cdots+a_{n-2}x^{n-2})
(G(x)=a_0+a_2x^2+cdots+a_{n-1}x^{n-1})
每项的次数看上去挺不友善的。。。想办法化简一下,得到以下两个多项式:
(F'(x)=a_1+a_2x+a_5x^2+cdots+a_{n-2}x^{frac{n-3}2})
(G'(x)=a_0+a_2x+a_4x^2+cdots+a_{n-1}x^{frac{n-1}2})
上面几个多项式的关系是这样的:
(F(x)=xF'(x^2))
(G(x)=G'(X^2))
(A(x)=F(x)+G(x)=xF'(x^2)+G'(x^2))
然后为了方便FFT优化,选取的(n)个(x)为(n)次单位根(omega_n^0,omega_n^1,omega_n^2,cdots,omega_n^{n-1})
注意!!!是(n)个而不是(n+1)个,因为这里的多项式项数为(n),即最高次项的次数只到(n-1),所以只用(n)个点就可以确定这个多项式
将(x=omega_n^k)代入上式,得到:(消去引理)
因为在单位复数根(omega_b^a)中,(aleq b)
所以上式只能用于(kleq n/2)的情况
设(k'leq n/2),那么当(k>n/2)时有(k=k'+n/2),代入,得到:
因为(k)和(k')都小于等于(n/2),所以可以看成同一个东西
对比上面两个结果,就可以发现两个式子只有一个常数项不同
所以每次只需要分治下去,求出(F'(omega_{n/2}^k))和(G'(omega_{n/2}^k))(注意这两个是不一样的,它们每项的系数不一样)
然后就可以合并为(A(omega_n^k))
即:
这个分治的过程中,每次合并需要(O(n)),一共分治(logn)次,所以时间复杂度为(O(nlogn))
更多的优化——迭代
上面的分治很容易想到用递归可以实现
但是递归非常耗内存,又跑得慢
所以尝试用迭代代替递归以让FFT更优秀
问题来了,怎样迭代呢?
找一下规律,就可以发现,奇偶分治到最底层((n=1))后,结果序列的二进制形式为原序列的二进制翻转后结果(好像不是特别准确)
直观一点,举个例子:
当(n=8)时,进行奇偶分治:
(0 1 2 3 4 5 6 7)
(0 2 4 6)(1 3 5 7)
(0 4)(2 6)(1 5)(3 7)
(0)(4)(2)(6)(1)(5)(3)(7)
原序列,结果序列的二进制对比:
原序列 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
原序列二进制 | 000 | 001 | 010 | 011 | 100 | 101 | 110 | 111 |
结果序列 | 0 | 4 | 2 | 6 | 1 | 5 | 3 | 7 |
结果序列二进制 | 000 | 100 | 010 | 110 | 001 | 101 | 011 | 111 |
规律就很明显了,就是二进制反过来
具体代码实现中的二进制翻转推导可以看下这篇文,很好理解
所以就可以直接跳过从原序列分治到单项的过程,直接用分治完的结果序列合并上去
成功优化
IDFT&IFFT
如果把DFT看作矩阵乘法,可以得到:
将对应(x_0,x_1,x_2,cdots,x_{n-1})的(omega_n^0,omega_n^1,omega_n^2,cdots,omega_n^{n-1})带入,得到:
化简,得到:
那么可以得到:
不难发现,需要求逆的矩阵是一个范德蒙德矩阵,那么它就可以这样转换一下:
所以上面的式子就可以转化为:
所以只要对IDFT采取类似FFT的奇偶分治优化,就能将(O(n^2))的IDFT加速为(O(nlogn))的IFFT
具体为将FFT中的单位复数根(omega_n^k)取为(omega_n^{-k})
所以可以将FFT和IFFT合到同一个函数中,区别为传入参数中的type为1还是-1
完成(撒花
例题:luogu P3803 【模板】多项式乘法(FFT)
代码如下,详见注释:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int N=3000000;
const double pi=acos(-1.0);
int n,m,nn,bit,rev[N];
struct Complex{//定义复数类以及它的运算
double a,b;
Complex(){}
Complex(double aa,double bb){
a=aa;b=bb;
}
Complex operator + (const Complex& x) const{
return Complex(a+x.a,b+x.b);
}
Complex operator - (const Complex& x) const{
return Complex(a-x.a,b-x.b);
}
Complex operator * (const Complex& x) const{
return Complex(a*x.a-b*x.b,a*x.b+b*x.a);
}
}a[N],b[N];
void init(){//初始化
for(int i=0;i<N;i++) a[i].a=a[i].b=b[i].a=b[i].b=0;
nn=1;
bit=0;
}
void before_fft(){
while(nn<=n+m+1){ //将多项式项数扩展到2的正整数次幂
nn<<=1;
bit++;
}
for(int i=0;i<nn;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));//求对应二进制翻转
}
void fft(Complex* a,int typ){
for(int i=0;i<nn;i++){//将数组变到奇偶分治到底层时的顺序
if(i<rev[i]) swap(a[i],a[rev[i]]); //if不能少,否则会有数交换两次后位置不变
}
for(int i=2;i<=nn;i<<=1){//每次要合并的两个多项式的项数和
Complex wn=Complex(cos(2*pi/i),typ*sin(2*pi/i));//通过欧拉公式求主次单位根
for(int st=0;st<nn;st+=i){//枚举归并位置
Complex wnk=Complex(1,0);
for(int k=0;k<i/2;k++){//枚举w_n的幂数k
Complex x=a[st+k];//对应fft推导过程中的g'k
Complex y=wnk*a[st+k+i/2];//对应fft推导过程中的wnk f'k
a[st+k]=x+y;//对应a_k
a[st+k+i/2]=x-y;//对应a_k+n/2
wnk=wnk*wn;
}
}
}
if(typ==-1){
for(int i=0;i<=nn;i++) a[i].a/=nn;//根据ifft的推导,这里要记得除以nn
}
}
int main(){
init();
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&a[i].a);
for(int i=0;i<=m;i++) scanf("%lf",&b[i].a);
before_fft();
fft(a,1);
fft(b,1);
//fft后的a,b分别为这两个多项式在点值表达下当x取wn1,wn2,...,wn(n-1)时的y
for(int i=0;i<=nn;i++) a[i]=a[i]*b[i];//点值表达下的多项式乘法,只将对应y相乘
fft(a,-1);//ifft
for(int i=0;i<=n+m;i++) printf("%d ",(int)(a[i].a+0.5));//这里要四舍五入
printf("
");
return 0;
}