再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)
写在前面
为了不使篇幅过长,预计将把基于论文的学习笔记分为三部分:
- DFT,IDFT,FFT的定义,实现与证明:快速傅里叶变换(FFT)学习笔记(其一)
- NTT的实现与证明:快速傅里叶变换(FFT)学习笔记(其二)
- 任意模数NTT与FFT的优化技巧
一些约定
- ([p(x)]=egin{cases}1,p(x)为真 \ 0,p(x)为假 end{cases})
- 本文中序列的下标从0开始
- 若(s)是一个序列,(|s|)表示(s)的长度
- 若大写字母如(F(x))表示一个多项式,那么对应的小写字母如(f)表示多项式的每一项系数,即(F(x)=sum_{i=0}^{n-1} f_ix^i)
循环卷积
DFT卷积的本质
考虑在(其一)中提到的卷积的定义式。
我们一般做FFT时忽略了式子中的(mod),其实它是在(mod 2^q)的意义下的循环卷积,只是因为(|a|,|b|,|c|<2^q),所以取不取模都没什么影响。
如果序列长度(n)是2的整数次幂,那么直接做就可以了。
如果序列长度(n)不是2的整数次幂考虑暴力的做法:先做一次普通FFT,再把(c_{k+n})加到(c_k)上。但是这样在做多次FFT时就必须一次一次做,比如多项式快速幂。下面给出了一种在(O(n log n))的时间内实现任意长度循环卷积的算法:Bluestein’s Algorithm
Bluestein’s Algorithm
注:原论文的推导可能有误
考虑DFT的式子
不妨设
(x_j=a_j omega_n^{frac{j^2}{2}}=a_j(cosfrac{j^2pi}{n}+ ext{i}sin{frac{j^2pi}{n}}))
(y_j=omega_n^{-frac{j^2}{2}}= cos frac{pi j^2}{n}- ext{i}sin frac{pi j^2}{n})
那么(a_i'=omega_n^{frac{j^2}{2}}sum_{j=0}^{n-1} x_j y_{i-j})
这已经很类似卷积的形式了,但是注意到(j)的上界是(n-1)而不是(i),(j-i)可能为负数。那么我们把(y)数组的长度扩大到(2n),定义:
(y_j=omega_n^{-frac{(j-n)^2}{2}}= cos frac{pi (j-n)^2}{n}- ext{i}sin frac{pi (j-n)^2}{n}).
这样(j<n)的时候就对应了(j-i)为负数的情形,(jgeq n)就对应了(j-i)为正的情形。然后对(x)和(y)用一般的FFT,最后的答案存储在(i+n)的位置上,也就是说真正的(a'_i)实际上对应了乘积结果的((x cdot y)_{i+n})
这样,我们就只做了3次FFT就求出了任意长度循环DFT。逆变换同理,只是换成共轭复数。注意到在上述的推导中我们没有用到单位根(omega)的任何性质,因此这里的(omega)可以换成任意复数(z),这样的变换称为Chirp Z-Transform,CZT.可见,CZT实际上是DFT的广义形式。
代码实现:
//com是手写复数类,省略
void fft(com *x,int *rev,int n,int type){
//为节约篇幅,fft部分省略,x为系数序列,rev为反转数组,n为长度,type=1表示DFT,type=-1表示IDFT
}
void bluestein(com *a,int n,int type){
//a为系数序列,n为长度,type=1表示DFT,type=-1表示IDFT
static com x[maxn*4+5],y[maxn*4+5];
static int rev[maxn*4+5];
memset(x,0,sizeof(x));
memset(y,0,sizeof(y));
//FFT前的预处理
int N=1,L=0;
while(N<n*4){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
//x[i],y[i]的定义见上式
for(int i=0;i<n;i++) x[i]=com(cos(pi*i*i/n),type*sin(pi*i*i/n))*a[i];
for(int i=0;i<n*2;i++) y[i]=com(cos(pi*(i-n)*(i-n)/n),-type*sin(pi*(i-n)*(i-n)/n));
fft(x,rev,N,1);
fft(y,rev,N,1);
for(int i=0;i<N;i++) x[i]*=y[i];
fft(x,rev,N,-1);
for(int i=0;i<n;i++){
a[i]=x[i+n]*com(cos(pi*i*i/n),type*sin(pi*i*i/n));//记得乘上常数
if(type==-1) a[i]/=n;//一定记得除以n,因为做一次Bluestein相当于一次FFT,IFFT最后要除n,这里也要除n
}
}
例题
[POJ 2821]TN's Kindom III(任意长度循环卷积的Bluestein算法)
分治FFT
一般我们用FFT的时候,序列的所有元素都已知。但是,如果序列本身是根据卷积定义的,就无法直接套FFT
举一个最简单的例子(f_i =sum_{j=1}^i f_{i-j}g_j).其中(g)给定,求(f). 由于我们卷积的时后后面的数基于前面的数,无法快速计算,时间复杂度退化到(O(n^2)). (虽然这个式子可以用(其四)中将会提到的多项式求逆解决,但是分治FFT更通用,可以处理很复杂的式子)
考虑分治: 设当前分治区间为([l,r]),假设我们求出了([l,mid])的答案,那么可以求出这些点对([mid+1,r])的影响。那么右半边的点(x in [mid+1,r])得到的贡献是(Delta_x=sum_{i=l}^{mid} f_i g_{x-i}).只需要把下标偏移一下(如([l,mid])偏移成([0,mid-l]),就是一个卷积的形式,可以运用FFT或NTT计算,计算完之后,把答案累加到数组上.
伪代码如下:
poly f,g;//上述的f,g
procedure calc(L,mid,R){
for i in [L,mid] : a[i-L] <- f[i]//下标偏移
for i in [1,R-L] : b[i-1] <- g[i]
a <- mul(a,b);//fft或ntt做多项式乘法
for i in [mid+1,R] f[i] <- f[i]+a[i-l-1]//累加贡献
}
procedure solve(l,mid){
if(l==r) return;
mid <- (l+r)/2
solve(l,mid);
calc(l,mid,r);
solve(mid+1,r)
}
时间复杂度分析:
(T(n)=2T(frac{n}{2})+n log_2n), 总复杂度(Theta(n log^2n))
下面是基于NTT的模板代码(Luogu 4721)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 300000
#define G 3
#define invG 332748118
#define inv2 499122177
#define mod 998244353
using namespace std;
typedef long long ll;
inline ll fast_pow(ll x,ll k){
ll ans=1;
while(k){
if(k&1) ans=ans*x%mod;
x=x*x%mod;
k>>=1;
}
return ans;
}
inline ll inv(ll x){
return fast_pow(x,mod-2);
}
void NTT(ll *x,int n,int type){
static int rev[maxn+5];
int tn=1;
int k=0;
while(tn<n){
tn*=2;
k++;
}
for(int i=0;i<tn;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
for(int i=0;i<n;i++){
if(i<rev[i]) swap(x[i],x[rev[i]]);
}
for(int len=1;len<n;len*=2){
int sz=len*2;
ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz);
for(int l=0;l<n;l+=sz){
int r=l+len-1;
ll gnk=1;
for(int i=l;i<=r;i++){
ll tmp=x[i+len];
x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
x[i]=(x[i]+gnk*tmp%mod)%mod;
gnk=gnk*gn1%mod;
}
}
}
if(type==-1){
int invsz=inv(n);
for(int i=0;i<n;i++) x[i]=x[i]*invsz%mod;
}
}
void mul(ll *a,ll *b,ll *ans,int sz){
NTT(a,sz,1);
NTT(b,sz,1);
for(int i=0;i<sz;i++) ans[i]=a[i]*b[i]%mod;
NTT(ans,sz,-1);
}
void cdq_divide(ll *f,ll *g,int l,int r){
static ll tmpa[maxn+5],tmpb[maxn+5];
if(l==r) return;
int mid=(l+r)>>1;
cdq_divide(f,g,l,mid);
int tn=1,k=0;
while(tn<r-l){
k++;
tn*=2;
}
for(int i=0;i<tn;i++) tmpa[i]=tmpb[i]=0;
for(int i=l;i<=mid;i++) tmpa[i-l]=f[i];
for(int i=1;i<=r-l;i++) tmpb[i-1]=g[i];
mul(tmpa,tmpb,tmpa,tn);
for(int i=mid+1;i<=r;i++) f[i]=(f[i]+tmpa[i-l-1])%mod;
cdq_divide(f,g,mid+1,r);
}
int n;
ll f[maxn+5],g[maxn+5];
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++) scanf("%lld",&g[i]);
f[0]=1;
cdq_divide(f,g,0,n-1);
for(int i=0;i<n;i++) printf("%lld ",f[i]);
}
容易发现,许多dp方程都有分治FFT的形式。对于此类dp方程,我们可以用分治FFT将转移复杂度由(O(n^2))降到(O(n log^2 n))
例题
[Codeforces 553E]Kyoya and Train(期望DP+Floyd+分治FFT)
FFT的弱常数优化
下面介绍一些优化FFT的常数的技巧。虽然这些技巧都只是对FFT的一些小优化,但是在某些题目中优化效果极其明显。
复杂算式中减少FFT次数
如果我们要计算一个复杂的多项式,如(A(x)=B(x)C(x)+D(x)E(x))
最简单的方法是分别计算(B(x)C(x))和(D(x)E(x)),这样需要做6次FFT. 但是如果先对(B,C,D,E)做DFT,然后直接用点值表达式计算(a_i=b_ic_i+d_ie_i),再把(a)IDFT回去。这样只需要做5次FFT,且多项式越复杂,这样的常数就越优秀。
例题
[BZOJ 3771] Triple(FFT+容斥原理+生成函数)
利用循环卷积
考虑对于两个长度为(n)的序列(a,b),计算它们的卷积(c)的第(0.5n)项到第(1.5n)项。传统的方法是补0扩充到(2n)的序列。但是因为FFT求得实际上是我们已经提到过的循环卷积,所以如果只补0到(1.5n)(上取整),对第(0.5n)项到第(1.5n)项无影响
在基于牛顿迭代的算法中,能起到较明显的优化作用。会在(其四)中详细介绍这些算法。
小范围暴力
由于FFT的常数较大。在数据范围较小的时候甚至不如(O(n^2))的暴力卷积的优秀。因此在做多次FFT和分治FFT的时候,如果当前的序列长度较小,可以采用暴力算法。
例题
[BZOJ 3509] [CodeChef] COUNTARI (FFT+分块)
快速幂乘法次数的优化
这个东西实际上比较鸡肋。因为多项式快速幂可以通过多项式(ln)和(exp)优化到(O(n log n)).但是为了应对考场上时间不够的情况,我们来考虑如何通过简单的实现来减少(O(n log^2n))的倍增快速幂的复杂度。
倍增法的思路是根据前面算过的乘积快速算出当前的乘积,如(1 o 2 o 4 o 8).最坏情况下需要(2 log_2n+C)次乘法。但这并不是下界。我们定义additional chain为一条链,最开始是1,后一个数减前一个数的差是链上这个是前面的某一个数。例如(1 o 2 o 4 o 6).(6-4=2)在前面出现过,(4-2=2)在前面出现过。那么根据这条additional chain计算6次幂的时候,可以从1次幂出发,用1次幂乘1次幂得到2次幂,再乘2次幂得到4次幂,再乘2次幂得到6次幂。
很可惜,对于数(k)求出得到(k)的最短additional chain是NP-hard的。但是有很好的近似算法。近似算法基于BFS。每次我们对于队头的数(x),枚举它对应的additional chain中的数(y),如果(x+y)还没有访问过那么将其入队,并将(x)对应的链后面接上(x+y). 这个预处理是(O(k))的,且对快速幂的常数优化很显著。
如果(k)很大,比如(10^{10000}),可以采用十进制快速幂。但是用Method of Four Russians(俗称四毛子算法),可以将乘法次数减少到(log_2n+O(frac{log n}{log log n})).具体方法见2017年国家集训队论文《非常规大小分块算法初探》
FFT的强常数优化
FFT的强常数优化一般是通过减少FFT次数来实现的
在这一节中,我们记(DFT(A(x)))表示多项式(A(x))(或序列)做DFT之后的结果,(IDFT(A(x)))同理
我们现在考虑最常见的一个模型:给出两个长度为(n+1)和(m+1)的多项式(A(x),B(x)),我们要计算他们的线性卷积。假设长度已经补齐为第一个大于(n+m+1)的2的整数幂(L)。
显然直接搞需要3次长度为(L)的FFT。毒瘤的Vladimir Smykalov在cf上最先给出了这个问题的优化算法。
DFT的合并
DFT的合并是指,对于两个序列(a),(b),我们只通过一次FFT就求出(DFT(a),DFT(b))
不妨设:
接下来我们开始推导公式。注意为了简洁,我们记(X=frac{2 pi jk}{2L}),( ext{conj}(z))表示(z)的共轭复数
也就是说,只要一次DFT算出(DFT(p)),就可以把序列反转再取共轭复数得到(DFT(q)).
由于DFT是线性变换,
其中(j)为(k)翻转后的数,即(j=egin{cases}0,k=0 \ L-k ,k>0 end{cases})
又由((4.1),(4.2))式
这样我们就可以从(q')推出(a',b'),也就是说一次DFT就能得到(a')和(b')了.
我们一共做了2次长度为(L)的FFT.
代码(UOJ#34):
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("%lf + %lf i ",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
void mul(ll *a,ll *b,ll *c,int n){
static com p[maxn+5],r[maxn+5];
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));//预处理单位根
for(int i=0;i<n;i++) p[i]=com(a[i],b[i]);//p[i]=a[i]+ib[i]
fft(p,n);
for(int i=0;i<n;i++){
int j=(i>0?(n-i):0);//0的位置需要特判一下
com q=p[j];
r[j]=(p[i]*p[i]-q.conj()*q.conj())*com(0,-0.25);//按照上面的式子
}
fft(r,n);//这里是用了第一篇中提到的反转技巧
for(int i=0;i<n;i++) c[i]=r[i].real/n+0.5;
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<n+m+1){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld
",c[i]);
}
IDFT的合并
IDFT的合并是指,对于两个序列(a),(b),我们只通过一次FFT就求出(IDFT(a),IDFT(b))
IDFT的合并非常简单。
设(r(x)=a(x)+ ext{i}b(x))
由于IDFT是线性变换
(IDFT(r(x))=IDFT(a(x))+ ext{i}IDFT(b(x)))
又因为(a(x))和(b(x))都是实数序列,那么(IDFT(r(x)))的实部就是(IDFT(a(x))),虚部就是(IDFT(b(x)))
形如((A+B)(C+D))的卷积的优化
在这一节中我们讨论((A(x)+B(x))(C(x)+D(x)))形式的卷积的优化.
一般的做法是对(A,B,C,D)都做一次DFT,然后按照这个式子直接计算,最后再IDFT回来。需要5次FFT.
而根据上面的合并技巧,先把(A(x),B(x))合并DFT,(C(x),D(x))合并DFT得到点值表达式.
由于((A(x)+B(x))(C(x)+D(x))=A(x)C(x)+A(x)D(x)+B(x)C(x)+B(x)D(x))
我们可以直接把点值表达式相乘得到这4个多项式。对于这4个多项式,分成2组合并做IDFT即可。
总共需要4次FFT.
大致代码如下:
void mul(ll *a,ll *b,ll *c,ll *d,ll *ans,int n){
static com p[maxn+5],q[maxn+5];
static com r[maxn+5],s[maxn+5];
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
for(int i=0;i<n;i++){
p[i]=com(a[i],b[i]);//打包A,B
q[i]=com(c[i],d[i]);//打包C,D
}
fft(p,n);
fft(q,n);
for(int i=0;i<n;i++){
int j=(i==0?0:n-i);
//得到DFT(A),DFT(B),DFT(C),DFT(D)
com da=(p[i]+p[j].conj())*0.5;
com db=(p[i]-p[j].conj())*com(0,-0.5);
com dc=(q[i]+q[j].conj())*0.5;
com dd=(q[i]-q[j].conj())*com(0,-0.5);
r[j]=da*dc+da*dd*com(0,1);//打包AC,AD
s[j]=db*dc+db*dd*com(0,1); //打包BC,BD
}
fft(r,n);
fft(s,n);
for(int i=0;i<n;i++){
ll ac,ad,bc,bd;
ac=(ll)(r[i].real/n+0.5);
ad=(ll)(r[i].imag/n+0.5);
bc=(ll)(s[i].real/n+0.5);
bd=(ll)(s[i].imag/n+0.5);
ans[i]=ac+ad+bc+bd;
}
}
卷积的终极优化
上述优化中我们只用到了DFT的思想。现在我们利用FFT的思想继续优化
同样拆分奇偶项,(A(x)=A_0(x^2)+xA_1(x^2))
我们只需要知道上式中(x^0,x^1,x^2)的系数
发现(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2))是奇数项的系数,(A_0(x^2)B_0(x^2))和(A_1(x^2)B_1(x^2))是偶数项的系数,而偶数项的两个东西都可以看成一个关于(x^2)的多项式。
我们先优化DFT的过程,观察((4.6))式的乘积形式((A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))).
我们发现,这个形式和上一节的((A+B)(C+D))很像,可以类似地优化。
令(p_k={a_0}_k+ ext{i}{a_1}_k,q_k={b_0}_k+ ext{i}{b_1}_k)
然后合并IDFT,再设两个辅助多项式
(注意我们把(x^2)换元成(x),做DFT的时候要乘上单位根)
那么我们只需要计算出(IDFT(G(x)))和(IDFT(F(x)))
设(R(x)=G(x)+mathrm{i} F(x))
那么因为IDFT是线性变换,(IDFT(R(x))=IDFT(G(x))+mathrm{i} IDFT(F(x)))
(IDFT的线性性这里不做证明,容易发现两个点值表达式相加再IDFT回来,显然系数也会相加)
显然这两个多项式IDFT的结果是实数。故我们只要求出(IDFT(R(x))),每一项系数的实部就是偶数项系数(G(x)),虚部就是奇数项系数(F(x))
我们再考虑把合并DFT弄进去,即式((4.3)(4.4)(4.5))
接下来我们尝试用(DFT(p_k),DFT(q_k))来表示(R(x)=G(x)+ ext{i}F(x)),为了推导简洁,我们省略(DFT)不写
那么
和上一节的((A+B)(C+D))不同,我们只用了3次长度为(L/2)的FFT,就求出了答案,这是由于FFT本身的性质。因为长度缩减了一半,我们不妨称它为(1.5)次FFT.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("%lf + %lf i ",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];//w(sz,k)=w(n,n/sz*k)
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
void mul(ll *a,ll *b,ll *c,int n){
static com p[maxn+5],q[maxn+5],r[maxn+5];
for(int i=0;i<n;i++){//合并做DFT
if(i%2==1){
p[i/2].imag=a[i];
q[i/2].imag=b[i];
}else{
p[i/2].real=a[i];
q[i/2].real=b[i];
}
}
n/=2;
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
fft(q,n);
fft(p,n);
for(int i=0;i<n;i++){
int j=(i>0?(n-i):0);
r[j]=p[i]*q[i]-(w[i]+1)*(p[i]-p[j].conj())*(q[i]-q[j].conj())*0.25;
}
fft(r,n);
for(int i=0;i<n;i++){
c[i*2]=r[i].real/n+0.5;
c[i*2+1]=r[i].imag/n+0.5;
}
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<=n+m+1){
L++;
N*=2;
}
for(int i=0;i<N/2;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-2));//注意这里的rev数组是对N/2做的,L要-1
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld
",c[i]);
}
任意模数NTT
三模数NTT
这是任意模数NTT的算法中最好理解的一种,它基于中国剩余定理。
定理5.1 若(m_1,m_2 ,dots m_n)两两互质,则对于(forall a_1,a_2 dots a_n)同余方程组
[egin{cases} x equiv a_1 (mod m_1) \ x equiv a_2 (mod m_2) \ dots \ x equiv a_n (mod m_n)end{cases} ]有整数解解,且可以用如下方式构造解
- 设(M=prod_{i=1}^n m_i,M_i=frac{M}{m_i})
- 设(M_i^{-1})为模(m_i)意义下(M_i)的逆元
- 则该方程组在模(M)意义下的唯一解为(x=sum_{i=1}^n a_iM_iM_i^{-1}) ,方程组的通解可以表示为(x+kM(k in mathbb{Z}))
这就是著名的中国剩余定理(Chinese Reminder Theorem,CRT)
证明:
对于(k eq i),(a_iM_iM_i^{-1} mod m_k=0), 而根据逆元的定义,(a_iM_iM_i^{-1} mod m_i =a_i). 再代入到(sum_{i=1}^n a_iM_iM_i^{-1}),原方程组成立。
回到任意模数NTT问题
模(M)意义下长度为(n)的序列做卷积,最大值可以到(n^2M).一般的题目中(n leq 10^5,Mleq 10^{9}),那么结果会到(10^{23})级别。用long double
等存储会丢失精度。那么我们可以选三个乘起来大于(10^{23})的NTT模数998244353,1004535809,469762049(选这三个模数的好处是他们的原根都是3,所以NTT部分写起来比较简洁)。然后分别在这三个模数的意义下做卷积。最后考虑把答案合并,我们只考虑某一位上的值(ans),容易写出:
显然(m_1,m_2,m_3)互质,那么我们可以利用中国剩余定理直接合并。但是,直接合并把三个模数乘起来的时候会超出long long
的范围。注意到两个模数相乘还是在long long
范围内的,可以两两合并,具体方法如下,
记(inv(a,m))表示(a)在模(m)下的逆元.根据CRT合并((5.2)(5.3))有:
不妨设(ans=km_1m_2+r),根据(5.4)有
(ans=km_1 m_2+r=q m_3+a_3 ag{5.6}),
在模 (m_3) 意义下有
(km_1 m_2+r equiv a_3 (mod m_3) ag{5.7})
因此(k=(a_3-r_2)inv(m_1m_2,m_3) (mod m_3)),不妨设(k=dm_3+e),代入(5.6)得
由于(m_1m_2m_3>ans),所以(d=0),也就是说,(ans=em_1m_2+r),其中(r=a_1m_2inv(m_1,m_1m_2)+a_2m_1inv(m_2,m_1m_2),e=(a_3-r_2)inv(m_1m_2,m_3))
const ll mm=m1*m2;
inline ll inv(ll a,ll m);
ll mul(ll a,ll b,ll m);//要用按位乘防止溢出
ll CRT(ll a1,ll a2,ll a3){
ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
return ((e%C)*(mm%C)%C+r%C)%C;
}
完整代码(LuoguP4245 【模板】任意模数NTT)
#include<iostream>
#include<cstdio>
#include<cstring>
#define m1 998244353ll
#define m2 1004535809ll
#define m3 469762049ll
#define G 3
#define maxn 1048576
using namespace std;
typedef long long ll;
const ll mm=m1*m2;
ll C;
ll fast_pow(ll x,ll k,ll m){
ll ans=1;
while(k){
if(k&1) ans=ans*x%m;
x=x*x%m;
k>>=1;
}
return ans;
}
inline ll inv(ll a,ll m){
return fast_pow(a%m,m-2,m); //一定要取模m
}
ll mul(ll a,ll b,ll m){
ll ans=0;
while(b){
if(b&1) ans=(ans+a)%m;
a=(a+a)%m;
b>>=1;
}
return ans;
}
ll CRT(ll a1,ll a2,ll a3){
//[Warning]You are not expected to understand this.
ll r=(mul(a1*m2%mm,inv(m2,m1),mm)+mul(a2*m1%mm,inv(m1,m2),mm))%mm;
ll e=((a3-r)%m3+m3)%m3*inv(mm,m3)%m3;
return ((e%C)*(mm%C)%C+r%C)%C;
}
int n,m,N,L;
int rev[maxn+5];
void NTT(ll *x,int n,int type,ll mod){
ll invG=inv(G,mod);
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz,mod);
for(int l=0;l<n;l+=sz){
int r=l+len-1;
ll gnk=1;
for(int i=l;i<=r;i++){
ll tmp=x[i+len];
x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
x[i]=(x[i]+gnk*tmp%mod)%mod;
gnk=gnk*gn1%mod;
}
}
}
if(type==-1){
ll invn=inv(n,mod);
for(int i=0;i<n;i++) x[i]=x[i]*invn%mod;
}
}
void fmul(ll *a,ll *b,ll *ans,int n,ll mod){
static ll ta[maxn+5],tb[maxn+5];
for(int i=0;i<n;i++) ta[i]=a[i];
for(int i=0;i<n;i++) tb[i]=b[i];
NTT(ta,n,1,mod);
if(a!=b) NTT(tb,n,1,mod);
for(int i=0;i<n;i++) ans[i]=ta[i]*tb[i]%mod;
NTT(ans,n,-1,mod);
}
ll a[maxn+5],b[maxn+5],c[3][maxn+5];
int main(){
scanf("%d %d %lld",&n,&m,&C);
for(int i=0;i<=n;i++){
scanf("%lld",&a[i]);
a[i]%=C;
}
for(int i=0;i<=m;i++){
scanf("%lld",&b[i]);
b[i]%=C;
}
N=1,L=0;
while(N<n+m+1){
N*=2;
L++;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
fmul(a,b,c[0],N,m1);
fmul(a,b,c[1],N,m2);
fmul(a,b,c[2],N,m3);
for(int i=0;i<n+m+1;i++){
printf("%lld ",CRT(c[0][i],c[1][i],c[2][i]));
}
}
容易发现,三模数NTT需要9次FFT,不是很优秀
拆系数FFT
我们之前讨论的优化都是针对FFT的,那不妨尝试用FFT解决任意模数NTT
最简单的想法是不取模,FFT完再取模。但是上文提到数值过大,long double
会丢失精度。
int128
是一个方法,但在OI比赛中不一定能使用。所以需要拆系数。
设(M_0=[sqrt{M}])
相当于把模数换成(M_0),降低大小。
代入对应的多项式
这不就是我们提到的((A+B)(C+D))形的卷积吗?
由于(k,b)都不超过(2^{15}),于是就不容易被卡精度了。实际操作中我们不必取(M_0=sqrt{M}),直接取(M_0=2^{15})即可。这样取模运算可以换成位运算,进一步减小常数。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 1000000
const double pi=acos(-1.0);
using namespace std;
typedef long long ll;
struct com{
double real;
double imag;
com(){
}
com(double _real,double _imag){
real=_real;
imag=_imag;
}
com(double x){
real=x;
imag=0;
}
void operator = (const com x){
this->real=x.real;
this->imag=x.imag;
}
void operator = (const double x){
this->real=x;
this->imag=0;
}
friend com operator + (com p,com q){
return com(p.real+q.real,p.imag+q.imag);
}
friend com operator + (com p,double q){
return com(p.real+q,p.imag);
}
void operator += (com q){
*this=*this+q;
}
void operator += (double q){
*this=*this+q;
}
friend com operator - (com p,com q){
return com(p.real-q.real,p.imag-q.imag);
}
friend com operator - (com p,double q){
return com(p.real-q,p.imag);
}
void operator -= (com q){
*this=*this-q;
}
void operator -= (double q){
*this=*this-q;
}
friend com operator * (com p,com q){
return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real);
}
friend com operator * (com p,double q){
return com(p.real*q,p.imag*q);
}
void operator *= (com q){
*this=(*this)*q;
}
void operator *= (double q){
*this=(*this)*q;
}
friend com operator / (com p,double q){
return com(p.real/q,p.imag/q);
}
void operator /= (double q){
*this=(*this)/q;
}
com conj(){
return com(real,-imag);
}
void print(){
printf("(%lf,%lf)
",real,imag);
}
};
int rev[maxn+5];
com w[maxn+5];
void fft(com *x,int n){
for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1;len<n;len*=2){
int sz=len*2;
for(int l=0;l<n;l+=sz){
int r=l+len-1;
for(int i=l;i<=r;i++){
com tmp=x[i+len];
x[i+len]=x[i]-tmp*w[n/sz*(i-l)];
x[i]=x[i]+tmp*w[n/sz*(i-l)];
}
}
}
}
ll mod;
void mul(ll *ina,ll *inb,ll *inc,int n){
static ll a[maxn+5],b[maxn+5],c[maxn+5],d[maxn+5];
static com p[maxn+5],q[maxn+5];
static com r[maxn+5],s[maxn+5];
for(int i=0;i<n;i++){
ina[i]=(ina[i]+mod)%mod;
inb[i]=(inb[i]+mod)%mod;
a[i]=ina[i]>>15;
b[i]=ina[i]&((1<<15)-1);
c[i]=inb[i]>>15;
d[i]=inb[i]&((1<<15)-1);
}
for(int i=0;i<n;i++) w[i]=com(cos(2*pi*i/n),sin(2*pi*i/n));
for(int i=0;i<n;i++){
p[i]=com(a[i],b[i]);//打包A,B
q[i]=com(c[i],d[i]);//打包C,D
}
fft(p,n);
fft(q,n);
for(int i=0;i<n;i++){
// p[i].print();
int j=(i==0?0:n-i);
//得到DFT(A),DFT(B),DFT(C),DFT(D)
com da=(p[i]+p[j].conj())*0.5;
com db=(p[i]-p[j].conj())*com(0,-0.5);
com dc=(q[i]+q[j].conj())*0.5;
com dd=(q[i]-q[j].conj())*com(0,-0.5);
r[j]=da*dc+da*dd*com(0,1);//打包AC,AD
s[j]=db*dc+db*dd*com(0,1); //打包BC,BD
}
fft(r,n);
fft(s,n);
for(int i=0;i<n;i++){
ll ac,ad,bc,bd;
ac=(ll)(r[i].real/n+0.5)%mod;
ad=(ll)(r[i].imag/n+0.5)%mod;
bc=(ll)(s[i].real/n+0.5)%mod;
bd=(ll)(s[i].imag/n+0.5)%mod;
inc[i]=((ac<<30)+((ad+bc)<<15)+bd)%mod;
}
}
int n,m;
ll a[maxn+5],b[maxn+5],c[maxn+5];
int main(){
scanf("%d %d %lld",&n,&m,&mod);
for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
int N=1,L=0;
while(N<=n+m+1){
L++;
N*=2;
}
for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
mul(a,b,c,N);
for(int i=0;i<n+m+1;i++) printf("%lld ",c[i]);
}