zoukankan      html  css  js  c++  java
  • [多项式算法](Part 3)MTT 任意模数FFT/NTT 学习笔记

    其他多项式算法传送门:

    [多项式算法](Part 1)FFT 快速傅里叶变换 学习笔记

    [多项式算法](Part 2)NTT 快速数论变换 学习笔记

    [多项式算法](Part 4)FWT 快速沃尔什变换 学习笔记

    [多项式算法](Part 5)分治FFT 学习笔记


    (3.Hard-MTT)

    定义

    • MTT((Maoxiao Theoretic Transforms))

      中文名称:不知道,上面的英文全称也是瞎编的

      (Most TLE Transforms)


    (Q:)现在学了FFT和NTT,那么MTT又是什么?有什么用?

    (A:)有大用

    如果现在需要求两个整数多项式卷积,序列长度(nle10^5),多项式系数(A_i,B_ile 10^9),答案对(ple 10^9)取模。

    这时你就会发现,在运算过程中值域会到达(10^{23})级别!使用FFT会炸精度,而NTT会因为模数的性质而失去作用。

    你可以选择高精度,但是高精不仅难实现,效率也较为低下,而python,java等自带高精的语言在部分赛事中也禁止使用。

    这时我们就需要使用MTT进行运算。


    分析

    MTT有(2)种方法,一种是三模数NTT,然后是拆系数FFT。

    其中NTT精度优秀,但常数较大,而FFT则相反。

    下面对这两种算法进行介绍。


    三模数NTT

    这个算法的主要思想是用(3)个满足NTT性质的(10^9)级别的模数进行NTT,得到(3)个序列,由中国剩余定理可知,因为值域为(10^{23}<10^{27}),所以我们可以由这(3)个序列确定每一个数。

    关于选取模数,可以自己写个程序算,也可以查表,这里推荐Miskcoo大大的表

    这里使用(3)个相加不会炸int的数:(469762049,998244353,1004535809)

    (3)个数原根都是(3),非常方便。

    假设最后得到(3)个序列:(A,B,C),现在要还原第(i)项的答案(x),问题就变成了一个同余方程组:

    [egin{cases} egin{equation} egin{split} xequiv A_i pmod{p_1}\ xequiv B_i pmod{p_2}\ xequiv C_i pmod{p_3} end{split} end{equation} end{cases} ]

    如果直接使用中国剩余定理合并,那么就需要使用int128或者高精度,两者都不太方便。

    我们可以使用EXCRT(拓展中国剩余定理)的方法:

    [egin{equation} egin{split} A_i+k_1p_1&=B_i+k_2p_2\ A_i+k_1p_1&equiv B_i pmod{p_2}\ k_1&equivfrac{B_i-A_i}{p_1} pmod{p_2} end{split} end{equation} ]

    那么就得到前(2)项的解(x=A_i+k_1p_1),接着和第(3)项合并:

    [egin{equation} egin{split} x+k_3p_1p_2&=C_i+k_4p_3\ x+k_3p_1p_2&equiv C_ipmod{p_3}\ k_3&equivfrac{C_i-x}{p_1p_2}pmod{p_3} end{split} end{equation} ]

    于是我们就求出了(3)项的通解(x'=x+k_3p_1p_2),那么答案就是(x'mod{p})


    代码

    综上所述,我们需要做(3)次NTT,即(9)次DFT(IDFT),常数较大(很大,我写得差),请注意常数优化。

    例题:Luogu P4245 【模板】任意模数NTT

    #include <cmath>
    #include <cstdio>
    #include <cctype>
    #include <cstring>
    #include <algorithm>
    #define rint register int
    typedef long long ll;
    
    //Having A Daydream...
    
    char In[1<<20],*p1=In,*p2=In,Ch;
    #define Getchar (p1==p2&&(p2=(p1=In)+fread(In,1,1<<20,stdin),p1==p2)?EOF:*p1++)
    inline int Getint()
    {
        register int x=0;
        while(!isdigit(Ch=Getchar));
        for(;isdigit(Ch);Ch=Getchar)x=x*10+(Ch^48);
        return x;
    }
    
    char Out[22222222],*Outp=Out,St[22],*Tp=St;
    inline void Putint(int x)
    {
        do *Tp++=x%10^48;while(x/=10);
        do *Outp++=*--Tp;while(St!=Tp);
    }
    
    inline ll Pow(ll a,ll b,ll p)
    {
        ll Res=1;
        for(a%=p;b;b>>=1,a=a*a%p)
            if(b&1)Res=Res*a%p;
        return Res;
    }
    
    int r[1<<18];
    namespace Poly
    {
        //const int p[3]={469762049,998244353,1004535809};
        #define Add(a,b) (((a)+(b))>=p?(a)+(b)-p:(a)+(b))
    
        void NTT(int n,int *A,int p,int g)
        {
            for(rint i=0;i<n;++i)if(i<r[i])std::swap(A[i],A[r[i]]);
            for(rint i=2,h=1;i<=n;i<<=1,h<<=1)
                for(rint j=0,Rs=Pow(g,(p-1)/i,p);j<n;j+=i)
                    for(rint k=0,Rt=1;k<h;++k,Rt=(ll)Rt*Rs%p)
                    {
                        int Tmp=(ll)A[j+h+k]*Rt%p;
                        A[j+h+k]=Add(A[j+k],p-Tmp),A[j+k]=Add(A[j+k],Tmp);
                    }
        }
    
        int A[1<<18],B[1<<18];
        void Multiply(int n,int *F,int *G,int p,int *S)
        {
            memcpy(A,F,n*sizeof(int));
            memcpy(B,G,n*sizeof(int));
            NTT(n,A,p,3),NTT(n,B,p,3);
            for(rint i=0;i<n;++i)A[i]=(ll)A[i]*B[i]%p;
            NTT(n,A,p,Pow(3,p-2,p));
            int In=Pow(n,p-2,p);
            for(rint i=0;i<n;++i)S[i]=(ll)A[i]*In%p;
        }
    }
    
    int n,m,p,F[1<<18],G[1<<18],S[3][1<<18];
    const int P[]={469762049,998244353,1004535809};
    
    int main()
    {
        n=Getint(),m=Getint(),p=Getint();
        for(rint i=0;i<=n;++i)F[i]=Getint();
        for(rint i=0;i<=m;++i)G[i]=Getint();
        for(m=n+m,n=1;n<=m;n<<=1);
        for(rint i=0,l=(int)log2(n);i<n;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
        for(rint i=0;i<3;++i)Poly::Multiply(n,F,G,P[i],S[i]);//计算F*G mod P[i],储存在S[i]
        for(rint i=0;i<=m;++i)
        {
            ll x=S[0][i]+((S[1][i]-S[0][i]+P[1])*Pow(P[0],P[1]-2,P[1])%P[1])*P[0];//前2项通项
            ll xs=(x%p+(S[2][i]-x%P[2]+P[2])*Pow((ll)P[0]*P[1],P[2]-2,P[2])%P[2]*P[0]%p*P[1]%p)%p;
            Putint(xs),*Outp++=i==m?'
    ':' ';
        }
        return fwrite(Out,1,Outp-Out,stdout),0;
    }
    

    代码长度 2.40KB

    用时 4.21s

    内存 12.87MB

    Max Case 522ms

    这个(10^5)的MTT时间和我(10^6)NTT时间差不多。。。这个算法可能快赶上(O(nlog^2n))


    拆系数FFT

    (M)为一个常数,把每一个多项式的系数拆成(A*M+B)的形式(两个多项式分别对应(A_1,B_1|A_2,B_2)),有:

    [(A_1M+B_1)(A_2M+B_2)=A_1A_2M^2+(A_1B_2+A_2B_1)M+B_1B_2 ]

    那么我们只需要分别计算(A_1A_2,A_1B_2,A_2B_1,B_1B_2),再相加就可以得到答案。

    (M=sqrt P) 时,上面(4)项都是(O(P))级别,所以FFT的范围在(10^{14})级别,就不会炸。

    (什么?(A_1A_2M^2)不是(O(P^2))级别的吗?)

    其实可以先计算(A_1A_2),最后把(M^2)乘上去的时候取模就好。

    如果分别计算(4)个卷积,这样就需要(12)次DFT(这岂不是比NTT还慢?)

    预处理(A_1,A_2,B_1,B_2)的DFT值可以优化到(7)次DFT:

    DFT((A_1)),DFT((A_2)),DFT((B_1)),DFT((B_2)),IDFT((A_1A_2)),IDFT((A_1B_2+A_2B_1)),IDFT((B_1B_2))

    (Q:) 这不还是很慢?

    其实我们还可以继续向下优化,使用合并DFT的方式可以将DFT优化到(4)次(详情见FFT 学习笔记的底部)

    其中(4)次DFT优化到(2)次,(3)次IDFT优化到(2)次。

    这样就可以跑得快了。(其实可以优化到"(3.5)"次DFT,但效果不明显且复杂,详见myy的2016集训队论文《再探快速傅里叶变换》)


    代码(无优化版):

    照着上面的思路写就可以了

    例题:Luogu P4245 【模板】任意模数NTT

    (7)次DFT,个人觉得比较好写)

    // luogu-judger-enable-o2
    #include <cmath>
    #include <cstdio>
    #include <cctype>
    #include <cstring>
    #include <algorithm>
    #define rint register int
    typedef long long ll;
    typedef long double ld;
    
    //Having A Daydream...
    
    char In[1<<20],*p1=In,*p2=In,Ch;
    #define Getchar (p1==p2&&(p2=(p1=In)+fread(In,1,1<<20,stdin),p1==p2)?EOF:*p1++)
    inline int Getint()
    {
        register int x=0;
        while(!isdigit(Ch=Getchar));
        for(;isdigit(Ch);Ch=Getchar)x=x*10+(Ch^48);
        return x;
    }
    
    char Out[22222222],*Outp=Out,St[22],*Tp=St;
    inline void Putint(int x)
    {
        do *Tp++=x%10^48;while(x/=10);
        do *Outp++=*--Tp;while(St!=Tp);
    }
    
    const double Eps=1e-8,Pi=std::acos(-1),e=std::exp(1);
    struct Complex
    {
        ld x,y;
        inline Complex operator+(const Complex &o)const{return (Complex){x+o.x,y+o.y};}
        inline Complex operator-(const Complex &o)const{return (Complex){x-o.x,y-o.y};}
        inline Complex operator*(const Complex &o)const{return (Complex){x*o.x-y*o.y,x*o.y+y*o.x};}
        inline Complex operator/(const ld k)const{return (Complex){x/k,y/k};}
        inline Complex Conj(){return (Complex){x,-y};}
    }Ome[1<<18],Inv[1<<18];
    
    int r[1<<18];
    namespace Poly
    {
        void Pre(int n)
        {
            for(rint i=0,l=(int)log2(n);i<n;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
            for(rint i=0;i<n;++i)
            {
                ld x=std::cos(2*Pi*i/n),y=std::sin(2*Pi*i/n);
                Ome[i]=(Complex){x,y},Inv[i]=(Complex){x,-y};
            }
        }
    
        void FFT(int n,Complex *A,Complex *T)
        {
            for(rint i=0;i<n;++i)if(i<r[i])std::swap(A[i],A[r[i]]);
            for(rint i=2;i<=n;i<<=1)
                for(rint j=0,h=i>>1;j<n;j+=i)
                    for(rint k=0;k<h;++k)
                    {
                        Complex Tmp=A[j+h+k]*T[n/i*k];
                        A[j+h+k]=A[j+k]-Tmp,A[j+k]=A[j+k]+Tmp;
                    }
        }
    
        Complex A1[1<<18],B1[1<<18],A2[1<<18],B2[1<<18];
        Complex A[1<<18],B[1<<18],C[1<<18];
        void MTT(int n,int p,int *F,int *G,int *S)
        {
            //这里为了方便直接设M=2^15=32768
            for(rint i=0;i<n;++i)
            {
                A1[i].x=F[i]>>15,B1[i].x=F[i]&0x7FFF;
                A2[i].x=G[i]>>15,B2[i].x=G[i]&0x7FFF;
            }
            FFT(n,A1,Ome),FFT(n,B1,Ome),FFT(n,A2,Ome),FFT(n,B2,Ome);
            for(rint i=0;i<n;++i)
            {
                A[i]=A1[i]*A2[i];
                B[i]=A1[i]*B2[i]+A2[i]*B1[i];
                C[i]=B1[i]*B2[i];
            }
            FFT(n,A,Inv),FFT(n,B,Inv),FFT(n,C,Inv);
            for(rint i=0;i<n;++i)
            {
                ll Av=(ll)round(A[i].x/n),Bv=(ll)round(B[i].x/n),Cv=(ll)round(C[i].x/n);
                S[i]=((Av%p<<30)+(Bv%p<<15)+Cv)%p;
            }
        }
    }
    
    int n,m,p,F[1<<18],G[1<<18],S[1<<18];
    
    int main()
    {
        n=Getint(),m=Getint(),p=Getint();
        for(rint i=0;i<=n;++i)F[i]=Getint();
        for(rint i=0;i<=m;++i)G[i]=Getint();
        for(m=n+m,n=1;n<=m;n<<=1);
        Poly::Pre(n),Poly::MTT(n,p,F,G,S);
        for(rint i=0;i<=m;++i)Putint(S[i]),*Outp++=i==m?'
    ':' ';
        return fwrite(Out,1,Outp-Out,stdout),0;
    }
    

    代码长度 2.96KB

    用时 2.59s

    内存 80.93MB

    Max Case 344ms

    Emm比上面的NTT还快了不少,可能我NTT写炸了?(Update:去翻了翻其他人的Code,我写的是个什么东西)

    在考场上推荐这个,简单易懂,缺点就是内存消耗较大,且精度低,需要long double

    Tips:std::coscos精度要高,其他的函数也一样


    代码(DFT优化版):

    (5)次DFT)

    其实我没有看懂IDFT怎么合并来着。。。

    为什么网上的代码全都和我不一样?

    例题:Luogu P4245 【模板】任意模数NTT

    //Luogu O2
    #include <cmath>
    #include <cstdio>
    #include <cctype>
    #include <cstring>
    #include <algorithm>
    #define rint register int
    typedef long long ll;
    typedef long double ld;
    
    //Having A Daydream...
    
    char In[1<<20],*p1=In,*p2=In,Ch;
    #define Getchar (p1==p2&&(p2=(p1=In)+fread(In,1,1<<20,stdin),p1==p2)?EOF:*p1++)
    inline int Getint()
    {
        register int x=0;
        while(!isdigit(Ch=Getchar));
        for(;isdigit(Ch);Ch=Getchar)x=x*10+(Ch^48);
        return x;
    }
    
    char Out[22222222],*Outp=Out,St[22],*Tp=St;
    inline void Putint(int x)
    {
        do *Tp++=x%10^48;while(x/=10);
        do *Outp++=*--Tp;while(St!=Tp);
    }
    
    const double Eps=1e-8,Pi=std::acos(-1),e=std::exp(1);
    struct Complex
    {
        ld x,y;
        inline Complex operator+(const Complex &o)const{return (Complex){x+o.x,y+o.y};}
        inline Complex operator-(const Complex &o)const{return (Complex){x-o.x,y-o.y};}
        inline Complex operator*(const Complex &o)const{return (Complex){x*o.x-y*o.y,x*o.y+y*o.x};}
        inline Complex operator/(const ld k)const{return (Complex){x/k,y/k};}
        inline Complex Conj(){return (Complex){x,-y};}
    }Ome[1<<18],Inv[1<<18],I=(Complex){0,1};
    
    int r[1<<18];
    namespace Poly
    {
        void Pre(int n)
        {
            for(rint i=0,l=(int)log2(n);i<n;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
            for(rint i=0;i<n;++i)
            {
                ld x=std::cos(2*Pi*i/n),y=std::sin(2*Pi*i/n);
                Ome[i]=(Complex){x,y},Inv[i]=(Complex){x,-y};
            }
        }
    
        void FFT(int n,Complex *A,Complex *T)
        {
            for(rint i=0;i<n;++i)if(i<r[i])std::swap(A[i],A[r[i]]);
            for(rint i=2;i<=n;i<<=1)
                for(rint j=0,h=i>>1;j<n;j+=i)
                    for(rint k=0;k<h;++k)
                    {
                        Complex Tmp=A[j+h+k]*T[n/i*k];
                        A[j+h+k]=A[j+k]-Tmp,A[j+k]=A[j+k]+Tmp;
                    }
        }
    
        Complex P[1<<18],Q[1<<18];
        void Double_DFT(int n,Complex *A,Complex *B,Complex *T)
        {
            for(rint i=0;i<n;++i)P[i]=A[i]+B[i]*I,Q[i]=A[i]-B[i]*I;
            FFT(n,P,T);
            for(rint i=0;i<n;++i)Q[i]=(i?P[n-i]:P[0]).Conj();
            for(rint i=0;i<n;++i)A[i]=(P[i]+Q[i])/2,B[i]=(P[i]-Q[i])*I/-2;
        }
    
        Complex A1[1<<18],B1[1<<18],A2[1<<18],B2[1<<18];
        Complex A[1<<18],B[1<<18],C[1<<18];
        void MTT(int n,int p,int *F,int *G,int *S)
        {
            //这里为了方便直接设M=2^15=32768
            for(rint i=0;i<n;++i)
            {
                A1[i].x=F[i]>>15,B1[i].x=F[i]&0x7FFF;
                A2[i].x=G[i]>>15,B2[i].x=G[i]&0x7FFF;
            }
            //FFT(n,A1,Ome),FFT(n,B1,Ome),FFT(n,A2,Ome),FFT(n,B2,Ome);
            Double_DFT(n,A1,B1,Ome),Double_DFT(n,A2,B2,Ome);
            for(rint i=0;i<n;++i)
            {
                A[i]=A1[i]*A2[i];
                B[i]=A1[i]*B2[i]+A2[i]*B1[i];
                C[i]=B1[i]*B2[i];
            }
            FFT(n,A,Inv),FFT(n,B,Inv),FFT(n,C,Inv);
            //Double_DFT(n,A,B,Inv),FFT(n,C,Inv);//IDFT怎么合并?
            for(rint i=0;i<n;++i)
            {
                ll Av=(ll)round(A[i].x/n),Bv=(ll)round(B[i].x/n),Cv=(ll)round(C[i].x/n);
                S[i]=((Av%p<<30)+(Bv%p<<15)+Cv)%p;
            }
        }
    }
    
    int n,m,p,F[1<<18],G[1<<18],S[1<<18];
    
    int main()
    {
        n=Getint(),m=Getint(),p=Getint();
        for(rint i=0;i<=n;++i)F[i]=Getint();
        for(rint i=0;i<=m;++i)G[i]=Getint();
        for(m=n+m,n=1;n<=m;n<<=1);
        Poly::Pre(n),Poly::MTT(n,p,F,G,S);
        for(rint i=0;i<=m;++i)Putint(S[i]),*Outp++=i==m?'
    ':' ';
        return fwrite(Out,1,Outp-Out,stdout),0;
    }
    

    代码长度 3.41KB

    用时 2.21s

    内存 94.60MB

    Max Case 282ms

    优化不是很大来着。。


    总结

    其实MTT也不是很难。只是一个小技巧?

    只是我tcl看不懂,以后就用(7)次DFT吧。。

    参考资料:

    2016国家集训队论文 《再探快速傅里叶变换》 -- 毛啸(myy,matthew99)

  • 相关阅读:
    POJ 1251 Jungle Roads 最小生成树
    HDU 1879 继续畅通工程 最小生成树
    HDU 1875 畅通工程再续 最小生成树
    HDU 1863 畅通工程 最小生成树
    CodeForces 445B DZY Loves Chemistry (并查集)
    UVA 11987 Almost Union-Find (并查集)
    UVALive(LA) 4487 Exclusive-OR(带权并查集)
    UVALive 3027 Corporative Network (带权并查集)
    UVALive(LA) 3644 X-Plosives (并查集)
    POJ 2524 Ubiquitous Religions (并查集)
  • 原文地址:https://www.cnblogs.com/LanrTabe/p/11314179.html
Copyright © 2011-2022 走看看