怎么说呢,快速xxx变换,都是运用了分治的思想来实现对卷积的加速,把n^2硬是降成log级的。
下面以快速傅里叶变换为例,记录一下其具体推导过程:
首先要利用单位复根的性质,证明几个引理:
引理一: w[d*n]^(d*k)=w[n]^k
证明:
w[d*n]^(d*k)
=e^(2*Pi*i*k*d/(2*d*n))
=e^(2*Pi*i*k/(2*n))
=w[n]^k
证毕
引理二: (w[n]^k)^2=w[n/2]^k
证明:
(w[n]^k)^2
=w[n]^(2*k)
=w[n/2]^k
证毕
引理三: Sigma[j=0,n-1]((w[n]^k)^j)=0
证明:
画复平面单位圆,显然
证毕
然后,开始分治:
A(x)=a0 + a1 *x + a2 *x^2 + ... + an *x^n
=A [a_tag&1 == 0] (x^2) + x *A [a_tag&1 == 1] (x^2)
分治完毕
FFT基本推导完成
附上代码:
FFT 快速傅里叶变换
void FFT(Comp *a,const int &n,const short &rev){ if(n==1) return ; for(int i=0;i<n;++i) tmp[i]=a[i]; for(int i=0;i<n;++i) if(i&1) a[(n>>1)+(i>>1)]=tmp[i]; else a[i>>1]=tmp[i]; Comp *a0=a; Comp *a1=a+(n>>1); FFT(a0,n>>1,rev); FFT(a1,n>>1,rev); Comp cur(1,0); const double alpha=PI*2/n*rev; Comp step=exp(I*alpha); for(int k=0;k<(n>>1);++k){ tmp[k]=a0[k]+cur*a1[k]; tmp[k+(n>>1)]=a0[k]-cur*a1[k]; cur*=step; } for(int i=0;i<n;++i) a[i]=tmp[i]; }
NTT 快速数论变换
#include<bits/stdc++.h> #define ll long long #define MAXN 1<<20 using namespace std; const int p=479<<21|1; ll inv,gn[MAXN],gn_inv[MAXN]; int n; inline ll qpow(ll a,ll b){ ll ans=1; while(b){ if(b&1) ans=ans*a%p; a=a*a%p; b>>=1; } return ans; } inline void rev(const int &n,ll r[]){ for(int i=0,j=0;i<n;++i){ if(i>j) swap(r[i],r[j]); for(int l=n>>1;(j^=l)<l;l>>=1); } } inline void prelude(){ inv=qpow(3,p-2); for(int i=1;i<=n;i<<=1) gn[i]=qpow(3,(p-1)/i), gn_inv[i]=qpow(inv,(p-1)/i); } inline void NTT(int n,ll* r,short f){ rev(n,r); for(int i=2;i<=n;i<<=1){ int m=i>>1; for(int j=0;j<n;j+=i){ ll w=1,wn=(f==1?gn[i]:gn_inv[i]); for(int k=0;k<m;++k){ ll z=r[j+m+k]*w%p; r[j+m+k]=(r[j+k]-z+p)%p; r[j+k]=(r[j+k]+z)%p; w=w*wn%p; } } } if(f==-1){ ll n_inv=qpow(n,p-2); for(int i=0;i<n;++i) r[i]=r[i]*n_inv%p; } }
FWT 快速沃尔什变换
class FWT{ public: inline void fwt(int *a, int n){ for(int d = 1; d < n; d <<= 1){ for(int m = d<<1, i = 0; i < n; i += m){ for(int j = 0; j < d; j++){ int x = a[i+j], y = a[i+j+d]; a[i+j] = x+y; a[i+j+d] = x-y; //and a[i+j] = x+y; //or a[i+j+d] = x+y; } } } } inline void ufwt(int *a, int n){ for(int d = 1; d < n; d <<= 1){ for(int m = d<<1, i = 0; i < n; i += m){ for(int j = 0; j < d; j++){ int x = a[i+j], y = a[i+j+d]; a[i+j] = (x+y)/2; a[i+j+d] = (x-y)/2; //and a[i+j] = x-y //or a[i+j] = y-x } } } } inline void cal(int *a, int *b, int n){ fwt(a, n); fwt(b, n); for(int i = 0; i < n; i++) a[i] *= b[i]; ufwt(a, n); } }fwt;