zoukankan      html  css  js  c++  java
  • FFT/NTT/MTT学习笔记

    FFT/NTT/MTT

    Tags:数学

    作业部落

    评论地址


    前言

    这是网上的优秀博客
    并不建议初学者看我的博客,因为我也不是很了解FFT的具体原理

    一、概述

    两个多项式相乘,不用(N^2),通过(FFT)可以把复杂度优化到(O(NlogN))(NTT)能够取模,(MTT)可以对非(NTT)模数取模,相对来说(FFT)常数小些因为不要取模

    二、我们来背板子(FFT)

    先放一个板子(洛谷P3803 【模板】多项式乘法(FFT)

    #include<iostream>
    #include<cstdio>
    #include<cstdlib>
    #include<cmath>
    using namespace std;
    const int MAXN=3000005;
    const double pi=acos(-1); 
    int N,M,r[MAXN],l;
    struct Complex
    {
    	double rl,im;//real part / imaginary part
    	Complex(){rl=im=0;}//以下是初始化的板子,虽然不懂为什么可以这样写
    	Complex(double a,double b){rl=a,im=b;}
    	Complex operator + (Complex B)
    		{return Complex(rl+B.rl,im+B.im);}
    	Complex operator - (Complex B)
    		{return Complex(rl-B.rl,im-B.im);}
    	Complex operator * (Complex B)
    		{return Complex(rl*B.rl-im*B.im,rl*B.im+im*B.rl);}
    }A[MAXN],B[MAXN];//对A,B两个多项式进行乘法
    void FFT(Complex *P,int op)
    {
    	for(int i=1;i<N;i++)//这个叫Rader排序
    		/*
    		  假设原来P[1...n].id=1..n
    		  现在需要的序列是从1到n所对应的id分别为id[1..n],满足r[id[i]]是升序
    		  r[i]表示把i二进制上第1到l位的数反过来后的十进制数
    		 */			 
    		if(i<r[i]) swap(P[i],P[r[i]]);
    	//接下来的这个叫做蝴蝶操作,算法导论上有一张图较为清晰
    	for(int i=1;i<N;i<<=1)//表示操作区间集的每个区间的长度
    	{
    		Complex W=(Complex){cos(pi/i),op*sin(pi/i)};
    		for(int p=i<<1,j=0;j<N;j+=p)//表示每个区间集的最上端位置
    		{
    			Complex w=(Complex){1,0};//第0个单位复数根
    			/*
    			  转角公式:将一个点(x,y)绕原点逆时针旋转t后的点是(x*cost-y*sint,x*sint+y*cost)
    			  用三角函数和差化积公式容易得证
    			  单位复数根是把单位元分为若干等份,于是每次就要转一定角度
    			  用w=w*W实现转角
    			 */
    			for(int k=0;k<i;k++,w=w*W)//每个区间的最上端位置
    			{
    				Complex X=P[j+k],Y=w*P[j+k+i];//j+k+i便是每个区间下端位置
    				P[j+k]=X+Y;P[j+k+i]=X-Y;//所谓蝴蝶操作
    			}
    		}
    	}
    }
    int main()
    {
        cin>>N>>M;
    	for(int i=0;i<=N;i++) cin>>A[i].rl;
    	for(int i=0;i<=M;i++) cin>>B[i].rl;
    	//读入实部,便是系数
    	M+=N;//最终位数
    	for(N=1;N<=M;N<<=1) l++;l--;//FFT必须是2^k项才能做,这里把他补全
    	for(int i=0;i<N;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l);//r是rader排序,将每个i的二进制位反过来
    	FFT(A,1);FFT(B,1);//将AB化成点集形式
    	//形如(w0,y0),(w1,y1)...(wn,yn)的这些点确定一条线
    	for(int i=0;i<N;i++) A[i]=A[i]*B[i];//点集O(n)相乘
    	FFT(A,-1);//再将点集转化为系数表示的形式
    	for(int i=0;i<=M;i++) printf("%d ",(int)(A[i].rl/N+0.5));//这时虚部都是0了
    	return 0;
    }
    

    以下是预处理单位复数根的代码
    代码长度会小些,精度也要高,建议使用这种写法
    三角函数比乘法慢

    #include<iostream>
    #include<cstdio>
    #include<cstdlib>
    #include<cmath>
    #include<complex>
    using namespace std;
    const int MAXN=3e6+10;
    const double pi=acos(-1);
    int r[MAXN],N,M,l;
    complex<double>A[MAXN],B[MAXN],w[MAXN];
    void FFT(complex<double> *P,int op)
    {
    	for(int i=1;i<N;i++) if(i>r[i]) swap(P[i],P[r[i]]);
    	for(int i=1;i<N;i<<=1)
    		for(int p=i<<1,j=0;j<N;j+=p)
    			for(int k=0;k<i;k++)
    			{
    				complex<double> W=w[N/i*k];W.imag()*=op;//实际要得到的是cos(pi/i*k)
    				complex<double> X=P[j+k],Y=W*P[j+k+i];//QAQ这里总是忘记乘W
    				P[j+k]=X+Y;P[j+k+i]=X-Y;
    			}
    }
    int main()
    {
    	cin>>N>>M;
    	for(int i=0;i<=N;i++) cin>>A[i].real();
    	for(int i=0;i<=M;i++) cin>>B[i].real();
    	M+=N;
    	for(N=1;N<=M;N<<=1) l++;l--;
    	for(int i=0;i<N;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l);
    	for(int i=0;i<N;i++) w[i].real()=cos(pi/N*i),w[i]=imag()=sin(pi/N*i);
    	FFT(A,1);FFT(B,1);
    	for(int i=0;i<N;i++) A[i]=A[i]*B[i];
    	FFT(A,-1);
    	for(int i=0;i<=M;i++) printf("%d ",(int)(A[i].real()/N+0.5));
    	puts("");return 0;
    }
    
    

    记忆方式:
    循环的(i)枚举当前处理的长度
    (j)枚举第几组(两组两组进行)
    (k)枚举位置
    于是(j+k)表示某组的第一小组的一个位置,(i+j+k)是某组第二小组与第一小组对应的位置
    然后先加再减,记得乘上(W)

    注意点:
    1.最后要(int)(real()/N+0.5)
    2.由于N要放大所以空间开两倍!!

    三、我们再来背板子(NTT)

    还是那道题

    #include<iostream>
    #include<cstdio>
    #include<cstdlib>
    #include<cmath>
    using namespace std;
    const int N=3000005;
    const int mod=998244353;
    int r[N],l,n,m,A[N],B[N],w[N];
    int ksm(int a,int k)
    {
        int s=1,b=a;
        for(;k;k>>=1,b=1ll*b*b%mod)
            if(k&1) s=1ll*s*b%mod;
        return s;
    }
    void NTT(int *P,int op)
    {
        for(int i=0;i<n;i++) if(i<r[i]) swap(P[i],P[r[i]]);
        for(int i=1;i<n;i<<=1)
        {
            int W=ksm(3,(mod-1)/(i<<1));//3是998244353的一个原根
            if(op<0) W=ksm(W,mod-2);w[0]=1;
            for(int j=1;j<i;j++) w[j]=1ll*w[j-1]*W%mod;
            for(int j=0,p=i<<1;j<n;j+=p)
                for(int k=0;k<i;k++)
                {
                    int X=P[j+k],Y=1ll*w[k]*P[i+j+k]%mod;
                    P[j+k]=(X+Y)%mod;P[i+j+k]=((X-Y)%mod+mod)%mod;
                }
        }
    }
    int main()
    {
        cin>>n>>m;
        for(int i=0;i<=n;i++) cin>>A[i];
        for(int i=0;i<=m;i++) cin>>B[i];
        m=n+m;for(n=1;n<=m;n<<=1) l++;l--;
        for(int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l);
        NTT(A,1);NTT(B,1);
        for(int i=0;i<n;i++) A[i]=1ll*A[i]*B[i]%mod; NTT(A,-1);
        for(int i=0,inv=ksm(n,mod-2);i<n;i++) A[i]=1ll*A[i]*inv%mod;
        for(int i=0;i<=m;i++) printf("%d ",A[i]);
        return 0;
    }
    
    

    一个数(k)的原根(x)满足(x^1,x^2,x^3...x^{phi(k)})各不相同且(x^{phi(k)}=1)
    对于且仅对于(2,4,p,2p,p^r(p为奇质数))有原根存在
    NTT的原根就代替了FFT中的单位复数根,要求形式是(p=r*2^p+1)
    常用的(NTT)模数有(998244353(3))(1004535809(3))

    找质数的原根

    最暴力的方法是枚举原根,然后判断(x^1...x^{p-1})是否相同
    优化的话是检查(p-1)的所有质因数中,是否存在一个质因子(k)使得(x^{frac{p-1}{k}}=1),若存在,则该数不是原根,否则是原根

    证明(Thanks GXY)

    首先可以明确的是,若对于(m)属于([1,p-2]),没有(g^mequiv 1(mod p)),则g是一个原根
    因为(g^{m1}equiv g^{m2} equiv k),且(m1>m2),则一定有(g^{m2-m1}equiv 1)
    利用反证法,假设存在一个(m)使得(g^mequiv 1(mod p))

    分两种情况讨论:
    1.(gcd(p-1,m)!=1)
    (k=(p-1)/m=p_1^{a_1}p_2^{a_2}...p_i^{a_i})(p_i)为质数
    (g^{frac{p-1}{p_i}}=g^{k*p_1^{a_1}*..*P_i^{a_i-1}}=(g^k)^{p_1^{a_1}..p_i^{a_i}}=1),能够通过上述方法判定出来

    2.(gcd(p-1,m)==1)
    (g^mequiv g^{2m}equiv...equiv g^{km}equiv g^{p-1}equiv 1(mod p))
    (kmequiv x(mod p-1)),由于(gcd(m,p-1)==1),根据同余方程的EXGCD判断,(x)可以在([0,p-2])任意取值,都有符合条件的(k)使得式子成立
    根据欧拉定理/费马小定理得(g^{km}equiv g^{km\%(p-1)}equiv g^xequiv 1(mod p)),使得所有的(x)属于([0,p-2])都模p余1,也会在之前的方法中判断出来

    四、有个可以讲清的了(MTT)

    处理任意模数(NTT)问题
    (M=sqrt{mod})(这样子好像复杂度最优)
    然后多项式的每一项拆成(AM+B),于是(A)(B)都在(int)之内就不会爆(double)
    所以两个数相乘就成为了$$(A_1M+B_1)*(A_2M+B_2)=A_1A_2M^2+(A_1B_2+A_2B_1)M+B_1B_2$$分别进行(4)(DFT)(4)(IDFT)即可(一共8次,有些博客是7次,但是代码比我长)

    Code

    洛谷P4245 【模板】任意模数NTT

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<complex>
    #include<cmath>
    using namespace std;
    const double Pi=acos(-1);
    const int N=400100;
    const int M=30000;
    int n,m,p,F[N],G[N];
    int r[N],Ans[N],l,tt;
    complex<double> A1[N],B1[N],A2[N],B2[N],A[N],w[N];
    void FFT(complex<double> *P,int op)
    {
        for(int i=0;i<l;i++) if(r[i]<i) swap(P[i],P[r[i]]);
        for(int i=1;i<l;i<<=1)
            for(int p=i<<1,j=0;j<l;j+=p)
                for(int k=0;k<i;k++)
                {
                    complex<double> W=w[l/i*k];W.imag()*=op;
                    complex<double> X=P[j+k],Y=W*P[j+k+i];
                    P[j+k]=X+Y;P[j+k+i]=X-Y;
                }
    }
    void Work(complex<double> *P1,complex<double> *P2,int base)
    {
        for(int i=0;i<l;i++) A[i]=P1[i]*P2[i];FFT(A,-1);
        for(int i=0;i<=m+n;i++) (Ans[i]+=(long long)(A[i].real()/l+0.5)%p*base%p)%=p;
    }
    int main()
    {
        scanf("%d%d%d",&n,&m,&p);
        for(int i=0,x;i<=n;i++) scanf("%d",&x),A1[i].real()=x/M,B1[i].real()=x%M;
        for(int i=0,x;i<=m;i++) scanf("%d",&x),A2[i].real()=x/M,B2[i].real()=x%M;
        for(l=1;l<=n+m;l<<=1) tt++;tt--;
        for(int i=0;i<l;i++) r[i]=(r[i>>1]>>1)|((i&1)<<tt);
        for(int i=0;i<l;i++) w[i].real()=cos(Pi/l*i),w[i].imag()=sin(Pi/l*i);
        FFT(A1,1);FFT(A2,1);FFT(B1,1);FFT(B2,1);
        Work(A1,A2,M*M%p); Work(A1,B2,M%p);
        Work(A2,B1,M%p); Work(B1,B2,1);
        for(int i=0;i<=m+n;i++) printf("%d ",Ans[i]);
    }
    
    

    五、一些要点

    这一部分还没玩成,待博主把这些算法完全弄懂后再来填坑~
    NTT时(X-Y+mod)%mod时,Y为负数就可能爆int,可以不加mod然后最后输出的时候加
    当乘起来不会超过mod(注意是乘后累加),那么NTT可以代替FFT,否则不行,例子见MTT
    乘法通过原根变成加法再NTT
    字符串匹配问题的两种做法
    组合数公式给拆成可以NTT的形式

  • 相关阅读:
    判断ImageIcon创建成功
    Node中的explorer views的双击事件
    Oracle数据类型
    Sql三种行转列
    数据库迁移
    并发采集同一站点被封的解决方案
    .net获取版本号的三种方法
    List转DataSet
    Orcale自增长主键
    python学习笔记数字和表达式
  • 原文地址:https://www.cnblogs.com/xzyxzy/p/9263480.html
Copyright © 2011-2022 走看看