zoukankan      html  css  js  c++  java
  • 浅谈FFT&NTT

    复数及单位根

    复数的定义大概就是:(i^2=-1),其中(i)就是虚数单位。

    那么,在复数意义下,对于方程:

    [x^n=1 ]

    就必定有(n)个解,这(n)个解的分布一定是在复平面上,以圆点为圆心,半径为(1)的圆的(n)等分点。

    由于欧拉公式:

    [e^{i heta}=cos heta+icdot sin heta ]

    (2pi)带入:

    [e^{2ipi}=1 ]

    比较一下这个和上面的方程,设:

    [omega_n=e^{2ipi/n} ]

    那么可以得到上面方程的(n)个解分别为:

    [forall iin[0,n-1],x_i=omega_n^i ]

    那么,我们称这(n)个解为(n)次单位根。

    关于单位根,有以下性质:

    [omega_n^x=-omega_n^{x+frac{n}{2}},w_n^2=w_{frac{n}{2}} ]

    这些性质的证明都很简单。

    点值表达式

    考虑到,一个多项式可以看做是一个(n)次的函数,如果已知这个函数的(n+1)个点,那么就可以确定这个多项式。

    任取(n+1)个不同的数(x_i),知道了多项式的结果(F(x_i)),这个称作多项式的点值表达式

    离散傅里叶变换(Discrete Fourier Transform, DFT)

    对于一个(n-1)次多项式,取(n)个数(w_n^0,w_n^1...w_n^{n-1}),得到一个点值表达式,称作离散傅里叶变换

    先把这个多项式凑成(n=2^x)的形式,高位补(0)

    对于(F(omega_n^{k})),显然可以得到:

    [F(omega_n^k)=sum_{i=0}^{n-1}(omega_n^k)^icdot A_i ]

    其中(A_i)为系数。

    然后对这个进行奇偶分类,可得:

    [egin{align} F(omega_n^k)&=sum_{i=0}^{n/2-1}(omega_n^{k})^{2i}cdot A_{2i}+sum_{i=0}^{n/2-1}(omega_n^k)^{2i+1}cdot A_{2i+1}\ &=sum_{i=0}^{n/2-1}(omega_{n/2}^{k})^icdot A_{2i}+omega_n^kcdot sum_{i=0}^{n/2-1}(omega_{n/2}^k)^{i}cdot A_{2i+1} end{align} ]

    (F_0(x))为偶数项的系数构成的多项式,(F_1(x))为奇数项,这个显然是一个子问题。

    那么:

    [F(omega_n^k)=F_0(omega_{n/2}^k)+w_n^kcdot F_1(omega_{n/2}^k) ]

    所以,令(kleqslant n/2),则有:

    [F(omega_n^{k+n/2})=F_0(omega_{n/2}^k)+w_n^{k+n/2}cdot F_1(omega_{n/2}^k) ]

    即:

    [F(omega_n^{k})=F_0(omega_{n/2}^k)+w_n^{k}cdot F_1(omega_{n/2}^k) \F(omega_n^{k+n/2})=F_0(omega_{n/2}^k)-w_n^{k}cdot F_1(omega_{n/2}^k) ]

    递归计算即可,复杂度:

    [T(n)=2 cdot T(frac{n}{2})+O(n)=O(nlog n) ]

    离散傅里叶逆变换(Inverse Discrete Fourier Transform, IDFT)

    对于离散傅里叶变换,写成矩阵的形式就是:

    [egin{bmatrix} (omega_n^0)^0&(omega_n^0)^1&cdots & (omega_n^0)^{n-1}\ (omega_n^1)^0&(omega_n^1)^1&cdots & (omega_n^1)^{n-1}\ vdots&vdots&ddots&vdots\ (omega_n^{n-1})^0&(omega_n^{n-1})^1&cdots & (omega_n^{n-1})^{n-1}\ end{bmatrix} imes egin{bmatrix} A_0\A_1\vdots\A_{n-1} end{bmatrix} = egin{bmatrix} F(omega_n^0)\F(omega_n^1)\vdots\F(omega_n^{n-1}) end{bmatrix} ]

    现在,我们是知道了等号右边的(F),要求等号左边的(A)

    设上面的系数矩阵为(s),考虑下面这个矩阵,设为(t)

    [t=egin{bmatrix} (omega_n^{-0})^0&(omega_n^{-0})^1&cdots & (omega_n^{-0})^{n-1}\ (omega_n^{-1})^0&(omega_n^{-1})^1&cdots & (omega_n^{-1})^{n-1}\ vdots&vdots&ddots&vdots\ (omega_n^{-(n-1)})^0&(omega_n^{-(n-1)})^1&cdots & (omega_n^{-(n-1)})^{n-1}\ end{bmatrix} ]

    考虑矩阵(v=t imes s)

    对于(v_{i,j}),根据矩阵乘法规则,它会等于:

    [v_{i,j}=sum_{k=0}^{n-1}(omega_n^{-i})^{k}cdot (omega_{n}^{k})^{j}=sum_{k=0}^{n-1}omega_n^{k(j-i)} ]

    (i=j),则:

    [v_{i,j}=n ]

    否则:

    [v_{i,j}=sum_{k=0}^{n-1}omega_n^{k(j-i)}=frac{1-(omega_n^{j-i})^n}{1-omega_n^{j-i}} ]

    注意到:

    [omega_n^n=0 ]

    所以:

    [v_{i,j}=0 ]

    然后把这个矩阵写出来:

    [v=egin{bmatrix} n&0&cdots&0\ 0&n&cdots&0\ vdots&vdots&ddots&vdots\ 0&0&cdots&n end{bmatrix} ]

    然后可以发现,这个就是单位矩阵的(n)倍,即:

    [t imes s=ncdot epsilon ]

    然后考虑第一个矩阵的式子,等式两边同时左乘一个(t),可得:

    [ncdot egin{bmatrix} A_0\A_1\vdots\A_{n-1} end{bmatrix} = egin{bmatrix} (omega_n^{-0})^0&(omega_n^{-0})^1&cdots & (omega_n^{-0})^{n-1}\ (omega_n^{-1})^0&(omega_n^{-1})^1&cdots & (omega_n^{-1})^{n-1}\ vdots&vdots&ddots&vdots\ (omega_n^{-(n-1)})^0&(omega_n^{-(n-1)})^1&cdots & (omega_n^{-(n-1)})^{n-1}\ end{bmatrix} imes egin{bmatrix} F(omega_n^0)\F(omega_n^1)\vdots\F(omega_n^{n-1}) end{bmatrix} ]

    所以,(IDFT)的时候直接照搬(DFT),然后把(omega_n^k)改成(omega_n^{-k}),最后在除个(n)就好了。

    迭代实现

    由于上面的递归实现常数过大,不是很优秀,这里有一种迭代的实现方法。

    考虑我们把递归过程改成迭代,那么显然我们需要把顺序重新排列一下,然后每次把相邻的(2^k)个数合并就好了。

    (n=2^m),考虑第(i)次递归的时候,二进制下第(i)(0)的放左边,为(1)的放右边,那么可以发现,左边的所有数新位置的编号第(m-i+1)位都为(0),右边的为(1),这个可以自己画下图理解下。

    那么,设(rev(x))表示把(x)的二进制翻转的结果,即第(i)位和第(m-i+1)位交换。

    对于原序列第(i)个数,他在新序列的位置就应该是(rev(i))

    代码就比较好写了:

    #include<cmath>
    #include<cstdio>
    #include<iostream>
    #include<algorithm>
    using namespace std;
     
    void read(int &x) {
        x=0;int f=1;char ch=getchar();
        for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
        for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
    }
     
    void print(int x) {
        if(x<0) putchar('-'),x=-x;
        if(!x) return ;print(x/10),putchar(x%10+48);
    }
    void write(int x) {if(!x) putchar('0');else print(x);putchar('
    ');}
    
    const int maxn = 4e6+10;
    
    #define lf double
    
    const lf pi = acos(-1);
    
    struct complex {
    	lf real,imag;
    	complex () {}
    	complex (lf _real,lf _imag) {real=_real,imag=_imag;}
    	complex conj() {return complex(real,-imag);}  //共轭复数
    	complex operator = (const int &rhs) {real=rhs;return *this;}
    	complex operator + (const complex &rhs) const {return complex(real+rhs.real,imag+rhs.imag);}
    	complex operator - (const complex &rhs) const {return complex(real-rhs.real,imag-rhs.imag);}
    	complex operator * (const complex &rhs) const {return complex(real*rhs.real-imag*rhs.imag,imag*rhs.real+real*rhs.imag);}
    };   //手写的一个复数类
    
    complex es[maxn],ces[maxn],a[maxn],b[maxn];
    int n,m,N,pos[maxn],bit;
    
    void init() {
    	for(int i=0;i<N;i++) es[i]=complex(cos(2*pi/N*i),sin(2*pi/N*i));
    	for(int i=0;i<N;i++) ces[i]=es[i].conj();  //预处理单位根
    	for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));  //pos[x]表示rev(x)
    }
    
    void fft(complex *r,complex *w) {
    	for(int i=0;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);  //调整位置
    	for(int i=1;i<N;i<<=1) 
    		for(int j=0;j<N;j+=(i<<1))
    			for(int k=0;k<i;k++) {
    				complex x=r[j+k],y=w[N/(i<<1)*k]*r[j+k+i];  //迭代实现
    				r[j+k]=x+y,r[i+j+k]=x-y;
    			}
    }
    
    int main() {
    	read(n),read(m);
    	for(int i=0,x;i<=n;i++) read(x),a[i]=x;
    	for(int i=0,x;i<=m;i++) read(x),b[i]=x;
    	N=1;while(N<=n+m) N<<=1,bit++;
    	init();fft(a,es),fft(b,es);
    	for(int i=0;i<=N;i++) a[i]=a[i]*b[i];fft(a,ces);
    	for(int i=0;i<=n+m;i++) printf("%d ",(int)(a[i].real/N+0.5));puts("");  //记得答案要除N,这个其实应该写在fft函数里面。。
    	return 0;
    }
    

    这份代码在洛谷的模板P3803 【模板】多项式乘法(FFT)提交可以通过。

    快速数论变换(Fast Number-Theoretic Transform,FNT)

    这玩意其实一般叫做(NTT)

    考虑到上面(FFT)的过程用到了单位根的哪些性质:

    1. (omega_n^0,omega_n^1...omega_n^{n-1})互不相同,这保证了点值表达式可以成立。
    2. (omega_n^2=omega_{n/2})(omega_n^{k+n/2}=-omega_n^k)
    3. (omega_n^n=1),这保证了IDFT的正确性。

    对于模数(p=kcdot 2^s+1),且(p)为质数,设它的原根为(g),那么我们可以令(omega_n=g^{(p-1)/n})

    由于原根的性质,第一条显然是满足的。

    对于第二条:

    [omega_n^2=g^{2(p-1)/n}=g^{(p-1)/(n/2)}=omega_{n/2} ]

    并且:

    [omega_n^{n/2}=g^{(p-1)/2}=-1 ]

    也比较显然。

    对于第三点,其实就是费马小定理,显然满足,所以我们可以用这个来替代(omega_n),进行数论变换,代码也差不多。

    注意,对于质数(p=kcdot 2^s+1),它能处理的数据范围是(nleqslant 2^s)

    模板:题目和上题相同

    #include<bits/stdc++.h>
    using namespace std;
     
    void read(int &x) {
        x=0;int f=1;char ch=getchar();
        for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
        for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
    }
     
    void print(int x) {
        if(x<0) putchar('-'),x=-x;
        if(!x) return ;print(x/10),putchar(x%10+48);
    }
    void write(int x) {if(!x) putchar('0');else print(x);putchar('
    ');}
    
    const int maxn = 4e6+10;
    const int mod = 998244353;
    
    int n,m,N=1,bit,pos[maxn],es[maxn],ces[maxn],a[maxn],b[maxn];
    
    int qpow(int aa,int x) {
    	int res=1;
    	for(;x;x>>=1,aa=1ll*aa*aa%mod) if(x&1) res=1ll*res*aa%mod;
    	return res;
    }
    
    void ntt(int *r,int f) {
    	for(int i=0;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
    	for(int i=1;i<N;i<<=1) {
    		int wn=qpow(f==1?3:qpow(3,mod-2),(mod-1)/(i<<1));
    		for(int j=0,w=1;j<N;j+=(i<<1),w=1) 
    			for(int k=0;k<i;k++,w=1ll*w*wn%mod) {
    				int x=r[j+k],y=1ll*w*r[i+j+k]%mod;
    				r[j+k]=(x+y)%mod,r[i+j+k]=(x-y)%mod;
    			}
    	}
    }
    
    int main() {
    	read(n),read(m);
    	for(int i=0;i<=n;i++) read(a[i]);
    	for(int i=0;i<=m;i++) read(b[i]);
    	while(N<=n+m) N<<=1,bit++;
    	for(int i=0;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
    	ntt(a,1),ntt(b,1);
    	for(int i=0;i<=N;i++) a[i]=1ll*a[i]*b[i]%mod;
    	ntt(a,-1);int inv=qpow(N,mod-2);
    	for(int i=0;i<=n+m;i++) printf("%d ",((1ll*a[i]*inv%mod)+mod)%mod);puts("");
    	return 0;
    }
    

    任意模数NTT(MTT)

    设现在要算的是多项式(A imes B),模数可能不满足(p=kcdot 2^s+1),甚至可以不是个质数。

    如果直接(FFT)的话,显然会爆精度,现在考虑如何优化精度。

    (r=lceilsqrt{p} ceil),那么对于多项式的每一项,设系数为(s),显然可以写成(s=acdot r+b)的形式。.

    那么对于(scdot t),设(s=acdot r+b,t=ccdot r+d),那么(scdot t=accdot r^2+(ad+bc)cdot r+bd)

    所以,可以把一个多项式拆成两个,分别做(FFT),这样精度一般是不会爆的。

    然后正反一共做(8)(FFT)就好了。

    好像有只需要做4遍FFT的方法,以后填坑。。

    代码:题目出自【模板】任意模数NTT

    #include<cmath>
    #include<cstdio>
    #include<iostream>
    #include<algorithm>
    using namespace std;
    
    void read(int &x) {
        x=0;int f=1;char ch=getchar();
        for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
        for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
    }
     
    void print(int x) {
        if(x<0) putchar('-'),x=-x;
        if(!x) return ;print(x/10),putchar(x%10+48);
    }
    void write(int x) {if(!x) putchar('0');else print(x);putchar('
    ');}
    
    #define lf double 
    
    const int maxn = 4e5+10;
    
    typedef long long ll;
    
    struct complex {
    	lf r,i;
    	complex () {}
    	complex (lf _r,lf _i) {r=_r,i=_i;}
    	complex conj() {return complex(r,-i);}
    	complex operator = (const int &rhs) {r=rhs;return *this;}
    	complex operator - (const complex &rhs) const {return complex(r-rhs.r,i-rhs.i);}
    	complex operator + (const complex &rhs) const {return complex(r+rhs.r,i+rhs.i);}
    	complex operator * (const complex &rhs) const {return complex(r*rhs.r-i*rhs.i,r*rhs.i+i*rhs.r);}
    }w1[maxn],w2[maxn],a[maxn],b[maxn],c[maxn],d[maxn];
    
    int N,bit,n,m,s[maxn],t[maxn],mod,p,ans[maxn],pos[maxn];
    
    const lf pi = acos(-1);
    
    void init() {
    	N=1,bit=0;while(N<=n+m) N<<=1,bit++;
    	for(int i=0;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
    	w1[0]=1;for(int i=1;i<N;i++) w1[i]=complex(cos(pi*2*i/N),sin(pi*2*i/N));
    	for(int i=0;i<N;i++) w2[i]=w1[i].conj();
    }
    
    void fft(complex *r,complex *w,int f) {
    	for(int i=0;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
    	for(int i=1;i<N;i<<=1)
    		for(int j=0;j<N;j+=(i<<1))
    			for(int k=0;k<i;k++) {
    				complex x=r[j+k],y=w[N/(i<<1)*k]*r[i+j+k];
    				r[j+k]=x+y,r[i+j+k]=x-y;
    			}
    	if(f==-1) for(int i=0;i<N;i++) r[i].r/=N,r[i].i=0;
    }
    
    void mul(int *A,int *B,int *C) {
    	for(int i=0;i<=n;i++) a[i]=A[i]/p,b[i]=A[i]%p;
    	for(int i=0;i<=m;i++) c[i]=B[i]/p,d[i]=B[i]%p;
    	init();
    	fft(a,w1,1),fft(b,w1,1),fft(c,w1,1),fft(d,w1,1);
    	for(int i=0;i<N;i++) {
    		complex tmpa=a[i],tmpb=b[i],tmpc=c[i],tmpd=d[i];
    		a[i]=tmpa*tmpc,b[i]=tmpa*tmpd+tmpb*tmpc,c[i]=tmpb*tmpd;
    	}
    	fft(a,w2,-1),fft(b,w2,-1),fft(c,w2,-1);
    	for(int i=0;i<N;i++) {
    		ll tmpa=ll(a[i].r+0.5),tmpb=ll(b[i].r+0.5),tmpc=ll(c[i].r+0.5);
    		C[i]=(tmpa%mod*p%mod*p%mod+tmpb%mod*p%mod-mod+tmpc)%mod;
    	}
    }
    
    int main() {
    	read(n),read(m),read(mod);p=sqrt(mod)+1;
    	for(int i=0;i<=n;i++) read(s[i]),s[i]%=mod;
    	for(int i=0;i<=m;i++) read(t[i]),t[i]%=mod;
    	mul(s,t,ans);
    	for(int i=0;i<=n+m;i++) printf("%d ",(ans[i]+mod)%mod);puts("");
    	return 0;
    }
    
  • 相关阅读:
    绝对干货:供个人开发者赚钱免费使用的一些好的API接口
    科普技术贴:个人开发者的那些赚钱方式
    北漂程序员的笑与泪
    非著名程序员公众号
    北漂程序员的笑与泪
    【有人@我】Android中高亮变色显示文本中的关键字
    新时代的coder如何成为专业程序员
    自定义圆形控件RoundImageView并认识一下attr.xml
    偷天换日:网络劫持,网页js被伪装替换。
    jeesite 去掉 /a
  • 原文地址:https://www.cnblogs.com/hbyer/p/10325916.html
Copyright © 2011-2022 走看看