zoukankan      html  css  js  c++  java
  • 再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)

    再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)

    写在前面

    为了不使篇幅过长,预计将把基于论文的学习笔记分为三部分:

    1. DFT,IDFT,FFT的定义,实现与证明:快速傅里叶变换(FFT)学习笔记(其一)
    2. NTT的实现与证明:快速傅里叶变换(FFT)学习笔记(其二)
    3. 任意模数NTT与FFT的优化技巧

    一些约定

    1. ([p(x)]=egin{cases}1,p(x)为真 \ 0,p(x)为假 end{cases})
    2. 本文中序列的下标从0开始
    3. (s)是一个序列,(|s|)表示(s)的长度
    4. 若大写字母如(F(x))表示一个多项式,那么对应的小写字母如(f)表示多项式的每一项系数,即(F(x)=sum_{i=0}^{n-1} f_ix^i)

    循环卷积

    DFT卷积的本质

    考虑在(其一)中提到的卷积的定义式。

    [c_{r}=sum_{p, q}[(p+q) mod n=r] a_{p} b_{q} ag{1.1} ]

    我们一般做FFT时忽略了式子中的(mod),其实它是在(mod 2^q)的意义下的循环卷积,只是因为(|a|,|b|,|c|<2^q),所以取不取模都没什么影响。

    如果序列长度(n)是2的整数次幂,那么直接做就可以了。

    如果序列长度(n)不是2的整数次幂考虑暴力的做法:先做一次普通FFT,再把(c_{k+n})加到(c_k)上。但是这样在做多次FFT时就必须一次一次做,比如多项式快速幂。下面给出了一种在(O(n log n))的时间内实现任意长度循环卷积的算法:Bluestein’s Algorithm

    Bluestein’s Algorithm

    注:原论文的推导可能有误

    考虑DFT的式子

    [egin{aligned} a'_i&=sum_{j=0}^{n-1} a_j omega_n^{ij} \&=sum_{j=0}^{n-1} a_j omega_n^{frac{-(i-j)^2+i^2+j^2}{2}} \&= omega_n^{frac{i^2}{2}} sum_{j=0}^{n-1}a_j omega_n^{frac{j^2}{2}} omega_n^{-frac{(i-j)^2}{2}}end{aligned} ]

    不妨设

    (x_j=a_j omega_n^{frac{j^2}{2}}=a_j(cosfrac{j^2pi}{n}+ ext{i}sin{frac{j^2pi}{n}}))

    (y_j=omega_n^{-frac{j^2}{2}}= cos frac{pi j^2}{n}- ext{i}sin frac{pi j^2}{n})

    那么(a_i'=omega_n^{frac{j^2}{2}}sum_{j=0}^{n-1} x_j y_{i-j})

    这已经很类似卷积的形式了,但是注意到(j)的上界是(n-1)而不是(i),(j-i)可能为负数。那么我们把(y)数组的长度扩大到(2n),定义:

    (y_j=omega_n^{-frac{(j-n)^2}{2}}= cos frac{pi (j-n)^2}{n}- ext{i}sin frac{pi (j-n)^2}{n}).

    这样(j<n)的时候就对应了(j-i)为负数的情形,(jgeq n)就对应了(j-i)为正的情形。然后对(x)(y)用一般的FFT,最后的答案存储在(i+n)的位置上,也就是说真正的(a'_i)实际上对应了乘积结果的((x cdot y)_{i+n})

    这样,我们就只做了3次FFT就求出了任意长度循环DFT。逆变换同理,只是换成共轭复数。注意到在上述的推导中我们没有用到单位根(omega)的任何性质,因此这里的(omega)可以换成任意复数(z),这样的变换称为Chirp Z-Transform,CZT.可见,CZT实际上是DFT的广义形式。

    代码实现:

    //com是手写复数类,省略
    void fft(com *x,int *rev,int n,int type){
    	//为节约篇幅,fft部分省略,x为系数序列,rev为反转数组,n为长度,type=1表示DFT,type=-1表示IDFT
    } 
    void bluestein(com *a,int n,int type){ 
        //a为系数序列,n为长度,type=1表示DFT,type=-1表示IDFT
    	static com x[maxn*4+5],y[maxn*4+5];
    	static int rev[maxn*4+5];
    	memset(x,0,sizeof(x));
    	memset(y,0,sizeof(y));
        //FFT前的预处理
    	int N=1,L=0;
    	while(N<n*4){
    		L++;
    		N*=2;
    	}
    	for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
        //x[i],y[i]的定义见上式
    	for(int i=0;i<n;i++) x[i]=com(cos(pi*i*i/n),type*sin(pi*i*i/n))*a[i];
    	for(int i=0;i<n*2;i++) y[i]=com(cos(pi*(i-n)*(i-n)/n),-type*sin(pi*(i-n)*(i-n)/n));
    	fft(x,rev,N,1);
    	fft(y,rev,N,1);
    	for(int i=0;i<N;i++) x[i]*=y[i];
    	fft(x,rev,N,-1);
    	for(int i=0;i<n;i++){
    		a[i]=x[i+n]*com(cos(pi*i*i/n),type*sin(pi*i*i/n));//记得乘上常数
    		if(type==-1) a[i]/=n;//一定记得除以n,因为做一次Bluestein相当于一次FFT,IFFT最后要除n,这里也要除n 
    	} 
    }
    

    例题

    [POJ 2821]TN's Kindom III(任意长度循环卷积的Bluestein算法)

    分治FFT

    一般我们用FFT的时候,序列的所有元素都已知。但是,如果序列本身是根据卷积定义的,就无法直接套FFT

    举一个最简单的例子(f_i =sum_{j=1}^i f_{i-j}g_j).其中(g)给定,求(f). 由于我们卷积的时后后面的数基于前面的数,无法快速计算,时间复杂度退化到(O(n^2)). (虽然这个式子可以用(其四)中将会提到的多项式求逆解决,但是分治FFT更通用,可以处理很复杂的式子)

    考虑分治: 设当前分治区间为([l,r]),假设我们求出了([l,mid])的答案,那么可以求出这些点对([mid+1,r])的影响。那么右半边的点(x in [mid+1,r])得到的贡献是(Delta_x=sum_{i=l}^{mid} f_i g_{x-i}).只需要把下标偏移一下(如([l,mid])偏移成([0,mid-l]),就是一个卷积的形式,可以运用FFT或NTT计算,计算完之后,把答案累加到数组上.

    伪代码如下:

    poly f,g;//上述的f,g
    procedure calc(L,mid,R){
    	for i in [L,mid] : a[i-L] <- f[i]//下标偏移
    	for i in [1,R-L] : b[i-1] <- g[i]
    	a <- mul(a,b);//fft或ntt做多项式乘法
    	for i in [mid+1,R] f[i] <- f[i]+a[i-l-1]//累加贡献
    }
    procedure solve(l,mid){
    	if(l==r) return;
    	mid <- (l+r)/2
    	solve(l,mid);
    	calc(l,mid,r);
    	solve(mid+1,r)
    }
    

    时间复杂度分析:

    (T(n)=2T(frac{n}{2})+n log_2n), 总复杂度(Theta(n log^2n))

    下面是基于NTT的模板代码(Luogu 4721)

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<cmath> 
    #define maxn 300000
    #define G 3
    #define invG 332748118
    #define inv2 499122177
    #define mod 998244353
    using namespace std;
    typedef long long ll;
    inline ll fast_pow(ll x,ll k){
    	ll ans=1;
    	while(k){
    		if(k&1) ans=ans*x%mod;
    		x=x*x%mod;
    		k>>=1;
    	}
    	return ans;
    }
    inline ll inv(ll x){
    	return fast_pow(x,mod-2); 
    }
    
    void NTT(ll *x,int n,int type){
    	static int rev[maxn+5];
    	int tn=1;
    	int k=0;
    	while(tn<n){
    		tn*=2;
    		k++;
    	}
    	for(int i=0;i<tn;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
    	for(int i=0;i<n;i++){
    		if(i<rev[i]) swap(x[i],x[rev[i]]);
    	} 
    	for(int len=1;len<n;len*=2){
    		int sz=len*2;
    		ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz);
    		for(int l=0;l<n;l+=sz){
    			int r=l+len-1;
    			ll gnk=1;
    			for(int i=l;i<=r;i++){
    				ll tmp=x[i+len];
    				x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
    				x[i]=(x[i]+gnk*tmp%mod)%mod;
    				gnk=gnk*gn1%mod;
    			} 
    		} 
    	}
    	if(type==-1){
    		int invsz=inv(n);
    		for(int i=0;i<n;i++) x[i]=x[i]*invsz%mod; 
    	}
    }
    void mul(ll *a,ll *b,ll *ans,int sz){
    	NTT(a,sz,1);
    	NTT(b,sz,1);
    	for(int i=0;i<sz;i++) ans[i]=a[i]*b[i]%mod;
    	NTT(ans,sz,-1);
    } 
    
    
    void cdq_divide(ll *f,ll *g,int l,int r){
    	static ll tmpa[maxn+5],tmpb[maxn+5];
    	if(l==r) return; 
    	int mid=(l+r)>>1;
    	cdq_divide(f,g,l,mid);
    	int tn=1,k=0;
    	while(tn<r-l){
    		k++;
    		tn*=2; 
    	}
    	for(int i=0;i<tn;i++) tmpa[i]=tmpb[i]=0; 
    	for(int i=l;i<=mid;i++) tmpa[i-l]=f[i];
    	for(int i=1;i<=r-l;i++) tmpb[i-1]=g[i];
    	mul(tmpa,tmpb,tmpa,tn);
    	for(int i=mid+1;i<=r;i++) f[i]=(f[i]+tmpa[i-l-1])%mod;
    	cdq_divide(f,g,mid+1,r);
    }
    
    int n;
    ll f[maxn+5],g[maxn+5];
    int main(){
    	scanf("%d",&n);
    	for(int i=1;i<n;i++) scanf("%lld",&g[i]); 
    	f[0]=1;
    	cdq_divide(f,g,0,n-1);
    	for(int i=0;i<n;i++) printf("%lld ",f[i]); 
    } 
    

    容易发现,许多dp方程都有分治FFT的形式。对于此类dp方程,我们可以用分治FFT将转移复杂度由(O(n^2))降到(O(n log^2 n))

    例题

    [Codeforces 553E]Kyoya and Train(期望DP+Floyd+分治FFT)

    FFT的弱常数优化

    下面介绍一些优化FFT的常数的技巧。虽然这些技巧都只是对FFT的一些小优化,但是在某些题目中优化效果极其明显。

    复杂算式中减少FFT次数

    如果我们要计算一个复杂的多项式,如(A(x)=B(x)C(x)+D(x)E(x))

    最简单的方法是分别计算(B(x)C(x))(D(x)E(x)),这样需要做6次FFT. 但是如果先对(B,C,D,E)做DFT,然后直接用点值表达式计算(a_i=b_ic_i+d_ie_i),再把(a)IDFT回去。这样只需要做5次FFT,且多项式越复杂,这样的常数就越优秀。

    例题

    [BZOJ 3771] Triple(FFT+容斥原理+生成函数)

    利用循环卷积

    考虑对于两个长度为(n)的序列(a,b),计算它们的卷积(c)的第(0.5n)项到第(1.5n)项。传统的方法是补0扩充到(2n)的序列。但是因为FFT求得实际上是我们已经提到过的循环卷积,所以如果只补0到(1.5n)(上取整),对第(0.5n)项到第(1.5n)项无影响

    在基于牛顿迭代的算法中,能起到较明显的优化作用。会在(其四)中详细介绍这些算法。

    小范围暴力

    由于FFT的常数较大。在数据范围较小的时候甚至不如(O(n^2))的暴力卷积的优秀。因此在做多次FFT和分治FFT的时候,如果当前的序列长度较小,可以采用暴力算法。

    例题

    [BZOJ 3509] [CodeChef] COUNTARI (FFT+分块)

    快速幂乘法次数的优化

    这个东西实际上比较鸡肋。因为多项式快速幂可以通过多项式(ln)(exp)优化到(O(n log n)).但是为了应对考场上时间不够的情况,我们来考虑如何通过简单的实现来减少(O(n log^2n))的倍增快速幂的复杂度。

    倍增法的思路是根据前面算过的乘积快速算出当前的乘积,如(1 o 2 o 4 o 8).最坏情况下需要(2 log_2n+C)次乘法。但这并不是下界。我们定义additional chain为一条链,最开始是1,后一个数减前一个数的差是链上这个是前面的某一个数。例如(1 o 2 o 4 o 6).(6-4=2)在前面出现过,(4-2=2)在前面出现过。那么根据这条additional chain计算6次幂的时候,可以从1次幂出发,用1次幂乘1次幂得到2次幂,再乘2次幂得到4次幂,再乘2次幂得到6次幂。

    很可惜,对于数(k)求出得到(k)的最短additional chain是NP-hard的。但是有很好的近似算法。近似算法基于BFS。每次我们对于队头的数(x),枚举它对应的additional chain中的数(y),如果(x+y)还没有访问过那么将其入队,并将(x)对应的链后面接上(x+y). 这个预处理是(O(k))的,且对快速幂的常数优化很显著。

    如果(k)很大,比如(10^{10000}),可以采用十进制快速幂。但是用Method of Four Russians(俗称四毛子算法),可以将乘法次数减少到(log_2n+O(frac{log n}{log log n})).具体方法见2017年国家集训队论文《非常规大小分块算法初探》

    FFT的强常数优化

    FFT的强常数优化一般是通过减少FFT次数来实现的
    在这一节中,我们记(DFT(A(x)))表示多项式(A(x))(或序列)做DFT之后的结果,(IDFT(A(x)))同理

    我们现在考虑最常见的一个模型:给出两个长度为(n+1)(m+1)的多项式(A(x),B(x)),我们要计算他们的线性卷积。假设长度已经补齐为第一个大于(n+m+1)的2的整数幂(L)

    显然直接搞需要3次长度为(L)的FFT。毒瘤的Vladimir Smykalov在cf上最先给出了这个问题的优化算法。

    DFT的合并

    DFT的合并是指,对于两个序列(a),(b),我们只通过一次FFT就求出(DFT(a),DFT(b))

    不妨设:

    [P(x)=A(x)+ ext{i}B(x) ag{4.1} ]

    [Q(x)=A(x)- ext{i}B(x) ag{4.2} ]

    接下来我们开始推导公式。注意为了简洁,我们记(X=frac{2 pi jk}{2L}),( ext{conj}(z))表示(z)的共轭复数

    [egin{aligned} DFT(p_k) &=Aleft(omega_{2 L}^{k} ight)+i Bleft(omega_{2 L}^{k} ight) \ &=sum_{j=0}^{2 L-1} a_{j} omega_{2 L}^{j k}+i b_{j} omega_{2 L}^{j k} \ &=sum_{j=0}^{2 L-1}left(a_{j}+i b_{j} ight)(cos X+i sin X) end{aligned}]

    [egin{aligned} DFT(q_k) &=Aleft(omega_{2 L}^{k} ight)-i Bleft(omega_{2 L}^{k} ight) \ &=sum_{j=0}^{2 L-1} a_{j} omega_{2 L}^{j k}-i b_{j} omega_{2 L}^{j k} \ &=sum_{j=0}^{2 L-1}left(a_{j}-i b_{j} ight)(cos X+i sin X) \ &=sum_{j=0}^{2 L-1}left(a_{j} cos X+b_{j} sin X+i sin X-b_{j} cos X ight) \&=operatorname{conj}left(sum_{j=0}^{2 L-1}left(a_{j} cos X+b_{j} sin X ight)-ileft(a_{j} sin X-b_{j} cos X ight) ight)\ &=operatorname{conj}left(sum_{j=0}^{2 L-1}left(a_{j} cos (-X)-b_{j} sin (-X) ight)+ileft(a_{j} sin (-X)+b_{j} cos (-X) ight) ight)\ &=operatorname{conj}left(sum_{j=0}^{2 L-1}left(a_{j}+i b_{j} ight)(cos (-X)+i sin (-X)) ight)\ &=operatorname{conj}left(sum_{j=0}^{2 L-1}left(a_{j}+i b_{j} ight) omega_{2 i}^{-j k} ight)\ &=operatorname{conj}left(sum_{j=0}^{2 L-1}left(a_{j}+i b_{j} ight) omega_{2 L}^{(2 L-k) j} ight)\ &=operatorname{conj}left(p'[2 L-k] ight) end{aligned}]

    也就是说,只要一次DFT算出(DFT(p)),就可以把序列反转再取共轭复数得到(DFT(q)).

    由于DFT是线性变换,

    [DFT(a_k)=frac{DFT(p_k)+DFT(q_k)}{2}=frac{DFT(p_k)+ ext{conj}(DFT(p_j))}{2} ]

    其中(j)(k)翻转后的数,即(j=egin{cases}0,k=0 \ L-k ,k>0 end{cases})

    又由((4.1),(4.2))

    [DFT(a_k)=frac{DFT(p_k)+DFT(q_k)}{2} ag{4.3} ]

    [DFT(b_k)=- ext{i}frac{DFT(p_k)-DFT(q_k)}{2} ag{4.4} ]

    [DFT(a_k)DFT(b_k)= ext{i}frac{{DFT(p_k)}^2-{DFT(q_k)}^2}{4} ag{4.5} ]

    这样我们就可以从(q')推出(a',b'),也就是说一次DFT就能得到(a')(b')了.

    我们一共做了2次长度为(L)的FFT.

    代码(UOJ#34):

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<cmath>
    #define maxn 1000000
    const double pi=acos(-1.0);
    using namespace std; 
    typedef long long ll;
    struct com{
    	double real;
    	double imag;
    	com(){
    		
    	} 
    	com(double _real,double _imag){
    		real=_real;
    		imag=_imag;
    	}
    	com(double x){
    		real=x;
    		imag=0;
    	}
    	void operator = (const com x){
    		this->real=x.real;
    		this->imag=x.imag;
    	}
    	void operator = (const double x){
    		this->real=x;
    		this->imag=0;
    	}
    	friend com operator + (com p,com q){
    		return com(p.real+q.real,p.imag+q.imag);
    	}
    	friend com operator + (com p,double q){
    		return com(p.real+q,p.imag);
    	}
    	void operator += (com q){
    		*this=*this+q;
    	}
    	void operator += (double q){
    		*this=*this+q;
    	}
    	friend com operator - (com p,com q){
    		return com(p.real-q.real,p.imag-q.imag);
    	}
    	friend com operator - (com p,double q){
    		return com(p.real-q,p.imag);
    	}
    	void operator -= (com q){
    		*this=*this-q;
    	}
    	void operator -= (double q){
    		*this=*this-q;
    	}
    	friend com operator * (com p,com q){
    		return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
    	}
    	friend com operator * (com p,double q){
    		return com(p.real*q,p.imag*q);
    	} 
    	void operator *= (com q){
    		*this=(*this)*q;
    	}
    	void operator *= (double q){
    		*this=(*this)*q;
    	}
    	friend com operator / (com p,double q){
    		return com(p.real/q,p.imag/q);
    	} 
    	void operator /= (double q){
    		*this=(*this)/q;
    	} 
    	com conj(){
    		return com(real,-imag);
    	}
    	void print(){
    		printf("%lf + %lf i ",real,imag);
    	}
    };
    int rev[maxn+5];
    com w[maxn+5];
    void fft(com *x,int n){
    	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
    	for(int len=1;len<n;len*=2){
    		int sz=len*2;
    		for(int l=0;l<n;l+=sz){
    			int r=l+len-1;
    			for(int i=l;i<=r;i++){
    				com tmp=x[i+len];
    				x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
    				x[i]=x[i]+tmp*w[n/sz*(i-l)];
    			}
    		}
    	}
    }
    void mul(ll *a,ll *b,ll *c,int n){
    	static com p[maxn+5],r[maxn+5];
    	for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));//预处理单位根 
    	for(int i=0;i<n;i++) p[i]=com(a[i],b[i]);//p[i]=a[i]+ib[i]
    	fft(p,n);
    	for(int i=0;i<n;i++){
    		int j=(i>0?(n-i):0);//0的位置需要特判一下
    		com q=p[j];
    		r[j]=(p[i]*p[i]-q.conj()*q.conj())*com(0,-0.25);//按照上面的式子
    	}	
    	fft(r,n);//这里是用了第一篇中提到的反转技巧
    	for(int i=0;i<n;i++) c[i]=r[i].real/n+0.5;
    }
    
    int n,m; 
    ll a[maxn+5],b[maxn+5],c[maxn+5];
    int main(){
    	scanf("%d %d",&n,&m);
    	for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
    	for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
    	int N=1,L=0;
    	while(N<n+m+1){
    		L++;
    		N*=2;
    	} 
    	for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    	mul(a,b,c,N);
    	for(int i=0;i<n+m+1;i++) printf("%lld
    ",c[i]);
    }
    
    

    IDFT的合并

    IDFT的合并是指,对于两个序列(a),(b),我们只通过一次FFT就求出(IDFT(a),IDFT(b))

    IDFT的合并非常简单。
    (r(x)=a(x)+ ext{i}b(x))
    由于IDFT是线性变换
    (IDFT(r(x))=IDFT(a(x))+ ext{i}IDFT(b(x)))
    又因为(a(x))(b(x))都是实数序列,那么(IDFT(r(x)))的实部就是(IDFT(a(x))),虚部就是(IDFT(b(x)))

    形如((A+B)(C+D))的卷积的优化

    在这一节中我们讨论((A(x)+B(x))(C(x)+D(x)))形式的卷积的优化.

    一般的做法是对(A,B,C,D)都做一次DFT,然后按照这个式子直接计算,最后再IDFT回来。需要5次FFT.

    而根据上面的合并技巧,先把(A(x),B(x))合并DFT,(C(x),D(x))合并DFT得到点值表达式.
    由于((A(x)+B(x))(C(x)+D(x))=A(x)C(x)+A(x)D(x)+B(x)C(x)+B(x)D(x))
    我们可以直接把点值表达式相乘得到这4个多项式。对于这4个多项式,分成2组合并做IDFT即可。
    总共需要4次FFT.

    大致代码如下:

    void mul(ll *a,ll *b,ll *c,ll *d,ll *ans,int n){
    	static com p[maxn+5],q[maxn+5];
    	static com r[maxn+5],s[maxn+5];
    	for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
    	for(int i=0;i<n;i++){
    		p[i]=com(a[i],b[i]);//打包A,B 
    		q[i]=com(c[i],d[i]);//打包C,D 
    	}
    	fft(p,n);
    	fft(q,n);
    	for(int i=0;i<n;i++){
    		int j=(i==0?0:n-i);
    		//得到DFT(A),DFT(B),DFT(C),DFT(D) 
    		com da=(p[i]+p[j].conj())*0.5;
    		com db=(p[i]-p[j].conj())*com(0,-0.5);
    		com dc=(q[i]+q[j].conj())*0.5;
    		com dd=(q[i]-q[j].conj())*com(0,-0.5);
    		r[j]=da*dc+da*dd*com(0,1);//打包AC,AD 
    		s[j]=db*dc+db*dd*com(0,1); //打包BC,BD 
    	}
    	fft(r,n);
    	fft(s,n);
    	for(int i=0;i<n;i++){
    		ll ac,ad,bc,bd; 
    		ac=(ll)(r[i].real/n+0.5);
            ad=(ll)(r[i].imag/n+0.5);
            bc=(ll)(s[i].real/n+0.5);
            bd=(ll)(s[i].imag/n+0.5);
            ans[i]=ac+ad+bc+bd;
    	}
    }
    
    

    卷积的终极优化

    上述优化中我们只用到了DFT的思想。现在我们利用FFT的思想继续优化

    同样拆分奇偶项,(A(x)=A_0(x^2)+xA_1(x^2))

    [egin{aligned} A(x)B(x)&=(A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\ &=A_0(x^2)B_0(x^2)+x(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2))+x^2A_1(x^2)B_1(x^2) end{aligned} ag{4.6}]

    我们只需要知道上式中(x^0,x^1,x^2)的系数
    发现(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2))是奇数项的系数,(A_0(x^2)B_0(x^2))(A_1(x^2)B_1(x^2))是偶数项的系数,而偶数项的两个东西都可以看成一个关于(x^2)的多项式。

    我们先优化DFT的过程,观察((4.6))式的乘积形式((A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))).

    我们发现,这个形式和上一节的((A+B)(C+D))很像,可以类似地优化。
    (p_k={a_0}_k+ ext{i}{a_1}_k,q_k={b_0}_k+ ext{i}{b_1}_k)

    然后合并IDFT,再设两个辅助多项式

    [G(x)=DFT(A_0(x))cdot DFT(B_0(x))+omega_L^k DFT(A_1(x)) DFT(B_1(x)) ]

    (注意我们把(x^2)换元成(x),做DFT的时候要乘上单位根)

    [F(x)=DFT(A_0(x))cdot DFT(B_1(x))+ DFT(A_1(x)) DFT(B_0(x)) ]

    那么我们只需要计算出(IDFT(G(x)))(IDFT(F(x)))

    (R(x)=G(x)+mathrm{i} F(x))
    那么因为IDFT是线性变换,(IDFT(R(x))=IDFT(G(x))+mathrm{i} IDFT(F(x)))
    (IDFT的线性性这里不做证明,容易发现两个点值表达式相加再IDFT回来,显然系数也会相加)

    显然这两个多项式IDFT的结果是实数。故我们只要求出(IDFT(R(x))),每一项系数的实部就是偶数项系数(G(x)),虚部就是奇数项系数(F(x))

    我们再考虑把合并DFT弄进去,即式((4.3)(4.4)(4.5))

    接下来我们尝试用(DFT(p_k),DFT(q_k))来表示(R(x)=G(x)+ ext{i}F(x)),为了推导简洁,我们省略(DFT)不写

    [egin{aligned} g_k&=frac {p_k+ ext{conj}(p_j)}{2}cdot frac {q_k+ ext{conj}(q_j)}{2}+omega_L^kcdot frac {p_k- ext{conj}(p_j)}{-2i}cdot frac {q_k- ext{conj}(q_j)}{-2i}\ &=frac 1 4 [(p_k+ ext{conj}(p_j))cdot(q_k+ ext{conj}(q_j))-omega_L^kcdot(p_k- ext{conj}(p_j))cdot(q_k- ext{conj}(q_j))]\ \ f_k&=frac {p_k+ ext{conj}(p_j)} 2 cdot frac{q_k- ext{conj}(q_j)}{-2}i+frac {q_k+ ext{conj}(q_j)} 2 cdot frac{p_k- ext{conj}(p_j)}{-2}i\ &=frac i{-4}[2cdot p_kcdot q_k-2cdot ext{conj}(p_j)cdot ext{conj}(q_j)] end{aligned}]

    那么

    [egin{aligned} g_k+ ext{i} f_k&=frac 1 4 [(p_k+ ext{conj}(p_j))cdot(q_k+ ext{conj}(q_j))-w_L^kcdot(p_k- ext{conj}(p_j))cdot(q_k- ext{conj}(q_j))-2cdot p_kcdot q_k+2 ext{conj}(p_jcdot q_j)]\ &=frac 1 4 [-(p_k- ext{conj}(p_j))cdot(q_k- ext{conj}(q_j))+2cdot (p_kcdot q_k+ ext{conj}(p_jcdot q_j))\ &-w_L^kcdot(p_k- ext{conj}(p_j))cdot(q_k- ext{conj}(q_j))+2cdot p_kcdot q_k-2cdot ext{conj}(p_jcdot q_j)]\ &=q_kcdot p_k-frac 1 4[(1+w_L^k)cdot (p_k- ext{conj}(p_j))cdot(q_k- ext{conj}(q_j))]\ end{aligned}]

    和上一节的((A+B)(C+D))不同,我们只用了3次长度为(L/2)的FFT,就求出了答案,这是由于FFT本身的性质。因为长度缩减了一半,我们不妨称它为(1.5)次FFT.

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<cmath>
    #define maxn 1000000
    const double pi=acos(-1.0);
    using namespace std; 
    typedef long long ll;
    struct com{
    	double real;
    	double imag;
    	com(){
    		
    	} 
    	com(double _real,double _imag){
    		real=_real;
    		imag=_imag;
    	}
    	com(double x){
    		real=x;
    		imag=0;
    	}
    	void operator = (const com x){
    		this->real=x.real;
    		this->imag=x.imag;
    	}
    	void operator = (const double x){
    		this->real=x;
    		this->imag=0;
    	}
    	friend com operator + (com p,com q){
    		return com(p.real+q.real,p.imag+q.imag);
    	}
    	friend com operator + (com p,double q){
    		return com(p.real+q,p.imag);
    	}
    	void operator += (com q){
    		*this=*this+q;
    	}
    	void operator += (double q){
    		*this=*this+q;
    	}
    	friend com operator - (com p,com q){
    		return com(p.real-q.real,p.imag-q.imag);
    	}
    	friend com operator - (com p,double q){
    		return com(p.real-q,p.imag);
    	}
    	void operator -= (com q){
    		*this=*this-q;
    	}
    	void operator -= (double q){
    		*this=*this-q;
    	}
    	friend com operator * (com p,com q){
    		return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
    	}
    	friend com operator * (com p,double q){
    		return com(p.real*q,p.imag*q);
    	} 
    	void operator *= (com q){
    		*this=(*this)*q;
    	}
    	void operator *= (double q){
    		*this=(*this)*q;
    	}
    	friend com operator / (com p,double q){
    		return com(p.real/q,p.imag/q);
    	} 
    	void operator /= (double q){
    		*this=(*this)/q;
    	} 
    	com conj(){
    		return com(real,-imag);
    	}
    	void print(){
    		printf("%lf + %lf i ",real,imag);
    	}
    };
    int rev[maxn+5];
    com w[maxn+5];
    void fft(com *x,int n){
    
    	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
    	for(int len=1;len<n;len*=2){
    		int sz=len*2;
    		for(int l=0;l<n;l+=sz){
    			int r=l+len-1;
    			for(int i=l;i<=r;i++){
    				com tmp=x[i+len];
    				x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
    				x[i]=x[i]+tmp*w[n/sz*(i-l)];
    			}
    		}
    	}
    }
    void mul(ll *a,ll *b,ll *c,int n){
    	static com p[maxn+5],q[maxn+5],r[maxn+5];
    	for(int i=0;i<n;i++){//合并做DFT
    		if(i%2==1){
    			p[i/2].imag=a[i];
    			q[i/2].imag=b[i]; 
    		}else{
    			p[i/2].real=a[i];
    			q[i/2].real=b[i];
    		}
    	}
    	n/=2;
    	for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
    	fft(q,n);
    	fft(p,n);
    	for(int i=0;i<n;i++){
    		int j=(i>0?(n-i):0);
    		r[j]=p[i]*q[i]-(w[i]+1)*(p[i]-p[j].conj())*(q[i]-q[j].conj())*0.25;
    	}	
    	fft(r,n);
    	for(int i=0;i<n;i++){
    		c[i*2]=r[i].real/n+0.5;
    		c[i*2+1]=r[i].imag/n+0.5; 
    	}
    }
    
    int n,m; 
    ll a[maxn+5],b[maxn+5],c[maxn+5];
    int main(){
    	scanf("%d %d",&n,&m);
    	for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
    	for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
    	int N=1,L=0;
    	while(N<=n+m+1){
    		L++;
    		N*=2;
    	} 
    	for(int i=0;i<N/2;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-2));//注意这里的rev数组是对N/2做的,L要-1 
    	mul(a,b,c,N);
    	for(int i=0;i<n+m+1;i++) printf("%lld
    ",c[i]);
    }
    
    
    

    任意模数NTT

    三模数NTT

    这是任意模数NTT的算法中最好理解的一种,它基于中国剩余定理。

    定理5.1(m_1,m_2 ,dots m_n)两两互质,则对于(forall a_1,a_2 dots a_n)同余方程组

    [egin{cases} x equiv a_1 (mod m_1) \ x equiv a_2 (mod m_2) \ dots \ x equiv a_n (mod m_n)end{cases} ]

    有整数解解,且可以用如下方式构造解

    1. (M=prod_{i=1}^n m_i,M_i=frac{M}{m_i})
    2. (M_i^{-1})为模(m_i)意义下(M_i)的逆元
    3. 则该方程组在模(M)意义下的唯一解为(x=sum_{i=1}^n a_iM_iM_i^{-1}) ,方程组的通解可以表示为(x+kM(k in mathbb{Z}))

    这就是著名的中国剩余定理(Chinese Reminder Theorem,CRT)

    证明:

    对于(k eq i),(a_iM_iM_i^{-1} mod m_k=0), 而根据逆元的定义,(a_iM_iM_i^{-1} mod m_i =a_i). 再代入到(sum_{i=1}^n a_iM_iM_i^{-1}),原方程组成立。

    回到任意模数NTT问题

    (M)意义下长度为(n)的序列做卷积,最大值可以到(n^2M).一般的题目中(n leq 10^5,Mleq 10^{9}),那么结果会到(10^{23})级别。用long double等存储会丢失精度。那么我们可以选三个乘起来大于(10^{23})的NTT模数998244353,1004535809,469762049(选这三个模数的好处是他们的原根都是3,所以NTT部分写起来比较简洁)。然后分别在这三个模数的意义下做卷积。最后考虑把答案合并,我们只考虑某一位上的值(ans),容易写出:

    [egin{cases} ans=a_1( mod m_1) (5.2)\ans=a_2( mod m_2)(5.3)\ans=a_3( mod m_3) (5.4)end{cases} ]

    显然(m_1,m_2,m_3)互质,那么我们可以利用中国剩余定理直接合并。但是,直接合并把三个模数乘起来的时候会超出long long的范围。注意到两个模数相乘还是在long long范围内的,可以两两合并,具体方法如下,

    (inv(a,m))表示(a)在模(m)下的逆元.根据CRT合并((5.2)(5.3))有:

    [ans equiv a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2)(mod m_1m_2) ag{5.5} ]

    不妨设(ans=km_1m_2+r),根据(5.4)

    (ans=km_1 m_2+r=q m_3+a_3 ag{5.6}),

    在模 (m_3) 意义下有

    (km_1 m_2+r equiv a_3 (mod m_3) ag{5.7})

    因此(k=(a_3-r_2)inv(m_1m_2,m_3) (mod m_3)),不妨设(k=dm_3+e),代入(5.6)

    [ans=dm_1m_2m_3+em_1m_2+r ]

    由于(m_1m_2m_3>ans),所以(d=0),也就是说,(ans=em_1m_2+r),其中(r=a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2),e=(a_3-r_2)inv(m_1m_2,m_3))

    const ll mm=m1*m2;
    inline ll inv(ll a,ll m);
    ll mul(ll a,ll b,ll m);//要用按位乘防止溢出
    ll CRT(ll a1,ll a2,ll a3){
        ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
        ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
        return ((e%C)*(mm%C)%C+r%C)%C;
    }
    

    完整代码(LuoguP4245 【模板】任意模数NTT)

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #define m1 998244353ll
    #define m2 1004535809ll
    #define m3 469762049ll
    #define G 3
    #define maxn 1048576
    using namespace std; 
    typedef long long ll;
    const ll mm=m1*m2;
    ll C;
    ll fast_pow(ll x,ll k,ll m){
    	ll ans=1;
    	while(k){
    		if(k&1) ans=ans*x%m;
    		x=x*x%m;
    		k>>=1; 
    	}
    	return ans;
    }
    inline ll inv(ll a,ll m){
    	return fast_pow(a%m,m-2,m); //一定要取模m 
    } 
    
    ll mul(ll a,ll b,ll m){
    	ll ans=0;
    	while(b){
    		if(b&1) ans=(ans+a)%m;
    		a=(a+a)%m;
    		b>>=1;
    	}
    	return ans;
    }
    ll CRT(ll a1,ll a2,ll a3){
    	//[Warning]You are not expected to understand this.
        ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
        ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
        return ((e%C)*(mm%C)%C+r%C)%C;
    }
    
    int n,m,N,L;
    int rev[maxn+5];
    void NTT(ll *x,int n,int type,ll mod){
    	ll invG=inv(G,mod); 
    	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]); 
    	for(int len=1;len<n;len*=2){
    		int sz=len*2;
    		ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz,mod);
    		for(int l=0;l<n;l+=sz){
    			int r=l+len-1;
    			ll gnk=1;
    			for(int i=l;i<=r;i++){
    				ll tmp=x[i+len];
    				x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
    				x[i]=(x[i]+gnk*tmp%mod)%mod;
    				gnk=gnk*gn1%mod; 
    			}
    		}
    	} 
    	if(type==-1){
    		ll invn=inv(n,mod);
    		for(int i=0;i<n;i++) x[i]=x[i]*invn%mod; 
    	}
    } 
    void fmul(ll *a,ll *b,ll *ans,int n,ll mod){
    	static ll ta[maxn+5],tb[maxn+5];
    	for(int i=0;i<n;i++) ta[i]=a[i];
    	for(int i=0;i<n;i++) tb[i]=b[i];
    	NTT(ta,n,1,mod);
    	if(a!=b) NTT(tb,n,1,mod);
    	for(int i=0;i<n;i++) ans[i]=ta[i]*tb[i]%mod;
    	NTT(ans,n,-1,mod);
    }
    
    ll a[maxn+5],b[maxn+5],c[3][maxn+5];
    int main(){
    	scanf("%d %d %lld",&n,&m,&C);
    	for(int i=0;i<=n;i++){
    		scanf("%lld",&a[i]);
    		a[i]%=C;
    	}
    	for(int i=0;i<=m;i++){
    		scanf("%lld",&b[i]);
    		b[i]%=C;
    	}
    	N=1,L=0;
    	while(N<n+m+1){
    		N*=2;
    		L++;
    	}
    	for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    	fmul(a,b,c[0],N,m1);
    	fmul(a,b,c[1],N,m2);
    	fmul(a,b,c[2],N,m3);
    	for(int i=0;i<n+m+1;i++){
    		printf("%lld ",CRT(c[0][i],c[1][i],c[2][i]));
    	}
    }
    

    容易发现,三模数NTT需要9次FFT,不是很优秀

    拆系数FFT

    我们之前讨论的优化都是针对FFT的,那不妨尝试用FFT解决任意模数NTT

    最简单的想法是不取模,FFT完再取模。但是上文提到数值过大,long double会丢失精度。
    int128是一个方法,但在OI比赛中不一定能使用。所以需要拆系数。

    (M_0=[sqrt{M}])

    [egin{aligned} a_i=k[a_i]M_0+b[a_i]\ b_i=k[b_i]M_0+b[b_i]end{aligned}]

    相当于把模数换成(M_0),降低大小。
    代入对应的多项式

    [egin{aligned}A(x)=K_a(x)M_0+B_a(x)\ B(x)=K_b(x)M_0+B_b(x)\ A(x)B(x)=K_a(x)K_b(x)M_0^2+(K_a(x)B_b(x)+K_b(x)B_a(x))M_0+B_a(x)B_b(x) end{aligned}]

    这不就是我们提到的((A+B)(C+D))形的卷积吗?
    由于(k,b)都不超过(2^{15}),于是就不容易被卡精度了。实际操作中我们不必取(M_0=sqrt{M}),直接取(M_0=2^{15})即可。这样取模运算可以换成位运算,进一步减小常数。

    代码(LuoguP4245 【模板】任意模数NTT)

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<cmath>
    #define maxn 1000000
    const double pi=acos(-1.0);
    using namespace std; 
    typedef long long ll;
    struct com{
    	double real;
    	double imag;
    	com(){
    		
    	} 
    	com(double _real,double _imag){
    		real=_real;
    		imag=_imag;
    	}
    	com(double x){
    		real=x;
    		imag=0;
    	}
    	void operator = (const com x){
    		this->real=x.real;
    		this->imag=x.imag;
    	}
    	void operator = (const double x){
    		this->real=x;
    		this->imag=0;
    	}
    	friend com operator + (com p,com q){
    		return com(p.real+q.real,p.imag+q.imag);
    	}
    	friend com operator + (com p,double q){
    		return com(p.real+q,p.imag);
    	}
    	void operator += (com q){
    		*this=*this+q;
    	}
    	void operator += (double q){
    		*this=*this+q;
    	}
    	friend com operator - (com p,com q){
    		return com(p.real-q.real,p.imag-q.imag);
    	}
    	friend com operator - (com p,double q){
    		return com(p.real-q,p.imag);
    	}
    	void operator -= (com q){
    		*this=*this-q;
    	}
    	void operator -= (double q){
    		*this=*this-q;
    	}
    	friend com operator * (com p,com q){
    		return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
    	}
    	friend com operator * (com p,double q){
    		return com(p.real*q,p.imag*q);
    	} 
    	void operator *= (com q){
    		*this=(*this)*q;
    	}
    	void operator *= (double q){
    		*this=(*this)*q;
    	}
    	friend com operator / (com p,double q){
    		return com(p.real/q,p.imag/q);
    	} 
    	void operator /= (double q){
    		*this=(*this)/q;
    	} 
    	com conj(){
    		return com(real,-imag);
    	}
    	void print(){
    		printf("(%lf,%lf)
    ",real,imag);
    	}
    };
    int rev[maxn+5];
    com w[maxn+5];
    void fft(com *x,int n){
    	for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
    	for(int len=1;len<n;len*=2){
    		int sz=len*2;
    		for(int l=0;l<n;l+=sz){
    			int r=l+len-1;
    			for(int i=l;i<=r;i++){
    				com tmp=x[i+len];
    				x[i+len]=x[i]-tmp*w[n/sz*(i-l)];
    				x[i]=x[i]+tmp*w[n/sz*(i-l)];
    			}
    		}
    	}
    }
    ll mod; 
    void mul(ll *ina,ll *inb,ll *inc,int n){
    	static ll a[maxn+5],b[maxn+5],c[maxn+5],d[maxn+5];
    	static com p[maxn+5],q[maxn+5];
    	static com r[maxn+5],s[maxn+5];
    	for(int i=0;i<n;i++){
    		ina[i]=(ina[i]+mod)%mod;
    		inb[i]=(inb[i]+mod)%mod;
    		a[i]=ina[i]>>15;
    		b[i]=ina[i]&((1<<15)-1);
    		c[i]=inb[i]>>15;
    		d[i]=inb[i]&((1<<15)-1);
    	}
    	for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
    	for(int i=0;i<n;i++){
    		p[i]=com(a[i],b[i]);//打包A,B 
    		q[i]=com(c[i],d[i]);//打包C,D 
    	}
    	fft(p,n);
    	fft(q,n);
    	for(int i=0;i<n;i++){
    //		p[i].print();
    		int j=(i==0?0:n-i);
    		//得到DFT(A),DFT(B),DFT(C),DFT(D) 
    		com da=(p[i]+p[j].conj())*0.5;
    		com db=(p[i]-p[j].conj())*com(0,-0.5);
    		com dc=(q[i]+q[j].conj())*0.5;
    		com dd=(q[i]-q[j].conj())*com(0,-0.5);
    		r[j]=da*dc+da*dd*com(0,1);//打包AC,AD 
    		s[j]=db*dc+db*dd*com(0,1); //打包BC,BD 
    	}
    	fft(r,n);
    	fft(s,n);
    	for(int i=0;i<n;i++){
    		ll ac,ad,bc,bd; 
    		ac=(ll)(r[i].real/n+0.5)%mod;
            ad=(ll)(r[i].imag/n+0.5)%mod;
            bc=(ll)(s[i].real/n+0.5)%mod;
            bd=(ll)(s[i].imag/n+0.5)%mod;
            inc[i]=((ac<<30)+((ad+bc)<<15)+bd)%mod;
    	}
    }
    
    int n,m; 
    ll a[maxn+5],b[maxn+5],c[maxn+5];
    int main(){
    	scanf("%d %d %lld",&n,&m,&mod);
    	for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
    	for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
    	int N=1,L=0;
    	while(N<=n+m+1){
    		L++;
    		N*=2;
    	} 
    	for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    	mul(a,b,c,N);
    	for(int i=0;i<n+m+1;i++) printf("%lld ",c[i]);
    }
    
    
  • 相关阅读:
    学习进度条2
    构建之法阅读笔记04
    团队项目——班级派发布视频
    四则运算3
    软件工程结对作业02四则运算四
    构建之法阅读笔记01
    构建之法阅读笔记02
    学习进度条4
    学习进度条1
    返回一个二维整数数组中最大联通子数组的和
  • 原文地址:https://www.cnblogs.com/birchtree/p/12470386.html
Copyright © 2011-2022 走看看