zoukankan      html  css  js  c++  java
  • 拆系数FFT

    学习内容:国家集训队2016论文 - 再谈快速傅里叶变换

    模板题:http://uoj.ac/problem/34

    1.基本介绍

    对长度为L的(A(x),B(x))进行DFT,可以利用

    [egin{align} P(x)=A(x)+iB(x) ag{1} \ Q(x)=A(x)-iB(x) ag{2} end{align} ]

    (P(x))进行DFT,得到(F_p)

    (Q(x))的结果 DFT(F_q[k]=!(F_p[2L-k])),(!表示取共轭)(证明见论文)。

    [egin{align} DFT(A[k])=frac{F_p[k]+F_q[k]} 2 ag{3} \ DFT(B[k])=-ifrac{F_p[k]-F_q[k]} 2 ag{4} end{align} ]

    这就是两两合并计算DFT的方法,2次DFT优化为了1次。

    IDFT的计算有两种方法,一种是带入(-w_n^k),另一种是将序列[1..n-1]翻转,再进行FFT,两种方法结果都要除以n。

    //495ms
    #include <bits/stdc++.h>
    #define rep(i,l,r) for(int i=l,ed=r;i<ed;++i)
    typedef long long ll;
    const double PI = acos(-1);
    const int N = 1<<20;
    const int BUF_SIZE=33554431;
    using namespace std;
    
    struct buf{
        char a[BUF_SIZE],b[BUF_SIZE],*s,*t;
        buf():s(a),t(b){a[fread(a,1,sizeof a,stdin)]=0;}
        ~buf(){fwrite(b,1,t-b,stdout);}
        operator int(){
            int x=0;
            while(*s<48)++s;
            while(*s>32)
                x=x*10+*s++-48;
            return x;
        }
        void out(int x){
            static char c[12];
            char*i=c;
            if(!x)*t++=48;
            else{
                while(x){
                    int y=x/10;
                    *i++=x-y*10+48,x=y;
                }
                while(i!=c)*t++=*--i;
            }
            *t++=10;
        }
    }it;
    struct cp{
        double x,y;
        cp(double _x=0,double _y=0):x(_x),y(_y){}
        cp operator +(const cp&amp; b)const{return cp(x+b.x,y+b.y);}
        cp operator -(const cp&amp; b)const{return cp(x-b.x,y-b.y);}
        cp operator *(const cp&amp; b)const{return cp(x*b.x-y*b.y,x*b.y+y*b.x);}
        cp operator !()const{return cp(x,-y);}
    }w[N];
    void fft(cp p[],int n){
        for(int i=0,j=0;i<n;++i){
            if(i>j)swap(p[i],p[j]);
            for(int l=n>>1;(j^=l)<l;l>>=1);
        }
        for(int i=2;i<=n;i<<=1)
        for(int j=0,m=i>>1;j<n;j+=i)
            rep(k,0,m){
                cp b=w[n/i*k]*p[j+m+k];
                p[j+m+k]=p[j+k]-b;
                p[j+k]=p[j+k]+b;
            }
    }
    void conv(int n,ll *x,ll *y,ll *z){
        static cp p[N],q[N],h(0,-0.25);
        rep(i,0,n){
            w[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
            p[i]=cp(x[i],y[i]);
        }
        fft(p,n);
        rep(i,0,n){
            int j=i?(n-i):0;
            q[j]=(p[i]*p[i]-!p[j]*!p[j])*h;
        }
        fft(q,n);
        rep(i,0,n)z[i]=q[i].x/n+0.5;
    }
    int n,m,p;
    ll a[N],b[N],c[N];
    int main(){
        n=it+1;m=it+1;
        rep(i,0,n) a[i]=it;
        rep(i,0,m) b[i]=it;
        for(n+=m-1,p=1;p<n;p<<=1);
        conv(p,a,b,c);
        rep(i,0,n)it.out(c[i]);
        return 0;
    }
    

    2.更快的卷积

    (A(x))表示为(A_0(x^2)+xA_1(x^2))(、A_0(x^2)、xA_1(x^2))分别是偶次项、奇次项的和。

    那么

    [egin{align} A(x)B(x)&=(A_0(x^2)+xA_1(x^2))(B_0(x^2)+xB_1(x^2))\ &=A_0(x^2)B_0(x^2)+x(A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2))+x^2A_1(x^2)B_1(x^2) end{align} ]

    可以分别对(A_0(x)、A_1(x)、B_0(x)、B_1(x))计算DFT,然后再把上式(x^0,x^1,x^2)的系数算出来,再进行3次IDFT。共7次。

    DFT可以两两合并优化为2次,且是两次长度为L(原来是2L)的DFT。

    IDFT时也可以两两合并,于是就需要2次长度L的IDFT。共4次。

    如果这两次IDFT还可以两两合并,那就只要计算一次IDFT。共3次长度L的计算。

    推导如下:

    (A_0(x^2)B_1(x^2)+A_1(x^2)B_0(x^2))的 IDFT 结果就是奇数项的系数。(A_0(x^2)B_0(x^2))(x^2A_1(x^2)B_1(x^2)) 则是偶数项的系数。

    (A_0(x^2)B_0(x^2))(x^2A_1(x^2)B_1(x^2))看做是关于(x^2)的多项式,可以两两合并计算。令

    [g=DFT(A_0)cdot DFT(B_0)+w[k]DFT(A_1)cdot DFT(B_1)\ f=DFT(A_0)cdot DFT(B_1)+DFT(A_1)cdot DFT(B_0) ]

    (xA(x))就是(w_n^kcdot DFT(A))。我们只要计算出(IDFT(g))(IDFT(f))即可。

    如果 IDFT 的结果是实数,那么两个 IDFT 就可以合并计算,令

    [P(x)=g+icdot f ]

    那么

    [IDFT(P(x))=IDFT(f)+i cdot IDFT(g) ]

    于是取实部和虚部分别作为奇数和偶数项的系数即可。

    [j=egin{cases} 0& k=0\ n-k& k eq 0 end{cases} ]

    那么

    [egin{aligned} g&=frac {P_k+!P_j}{2}cdot frac {Q_k+!Q_j}{2}+w[k]cdot frac {P_k-!P_j}{-2i}cdot frac {Q_k-!Q_j}{-2i}\ &=frac 1 4 [(P_k+!P_j)cdot(Q_k+!Q_j)-w[k]cdot(P_k-!P_j)cdot(Q_k-!Q_j)]\ \ f&=frac {P_k+!P_j} 2 cdot frac{Q_k-!Q_j}{-2}i+frac {Q_k+!Q_j} 2 cdot frac{P_k-!P_j}{-2}i\ &=frac i{-4}[2cdot P_kcdot Q_k-2cdot !P_jcdot !Q_j] end{aligned} ]

    于是

    [egin{aligned} g+fcdot i&=frac 1 4 [(P_k+!P_j)cdot(Q_k+!Q_j)-w[k]cdot(P_k-!P_j)cdot(Q_k-!Q_j)-2cdot P_kcdot Q_k+2 !(P_jcdot Q_j)]\ &=frac 1 4 [-(P_k-!P_j)cdot(Q_k-!Q_j)+2cdot (P_kcdot Q_k+!(P_jcdot Q_j))\ &-w[k]cdot(P_k-!P_j)cdot(Q_k-!Q_j)+2cdot P_kcdot Q_k-2cdot !(P_jcdot Q_j)]\ &=Q_kcdot P_k-frac 1 4[(1+w[k])cdot (P_k-!P_j)cdot(Q_k-!Q_j)]\ end{aligned} ]

    //325ms
    #include <bits/stdc++.h>
    #define rep(i,l,r) for(int i=l,ed=r;i<ed;++i)
    typedef long long ll;
    const double PI = acos(-1);
    const int N = 1<<20;
    const int BUF_SIZE=33554431;
    using namespace std;
    
    struct buf{
        char a[BUF_SIZE],b[BUF_SIZE],*s,*t;
        buf():s(a),t(b){a[fread(a,1,sizeof a,stdin)]=0;}
        ~buf(){fwrite(b,1,t-b,stdout);}
        operator int(){
            int x=0;
            while(*s<48)++s;
            while(*s>32)
                x=x*10+*s++-48;
            return x;
        }
        void out(int x){
            static char c[12];
            char*i=c;
            if(!x)*t++=48;
            else{
                while(x){
                    int y=x/10;
                    *i++=x-y*10+48,x=y;
                }
                while(i!=c)*t++=*--i;
            }
            *t++=10;
        }
    }it;
    struct cp{
        double x,y;
        cp(double _x=0,double _y=0):x(_x),y(_y){}
        cp operator +(const cp&amp; b)const{return cp(x+b.x,y+b.y);}
        cp operator -(const cp&amp; b)const{return cp(x-b.x,y-b.y);}
        cp operator *(const cp&amp; b)const{return cp(x*b.x-y*b.y,x*b.y+y*b.x);}
        cp operator *(double b)const{return cp(b*x,b*y);}
        cp operator !()const{return cp(x,-y);}
    }w[N];
    void fft(cp *p,int n){
        for(int i=0,j=0;i<n;++i){
            if(i>j)swap(p[i],p[j]);
            for(int l=n>>1;(j^=l)<l;l>>=1);
        }
        for(int i=2;i<=n;i<<=1)
        for(int j=0,m=i>>1;j<n;j+=i)
            rep(k,0,m){
                cp b=w[n/i*k]*p[j+m+k];
                p[j+m+k]=p[j+k]-b;
                p[j+k]=p[j+k]+b;
            }
    }
    void conv(int n,ll *x,ll *y,ll *z){
        static cp p[N],q[N],a[N];
        rep(i,0,n){
            (i&amp;1?p[i>>1].y:p[i>>1].x)=x[i];
            (i&amp;1?q[i>>1].y:q[i>>1].x)=y[i];
        }
        rep(i,0,n>>=1)w[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
        fft(p,n);fft(q,n);
        rep(i,0,n){
            int j=i?n-i:0;
            a[j]=p[i]*q[i]-((cp(1,0)+w[i])*(p[i]-!p[j])*(q[i]-!q[j]))*0.25;
        }
        fft(a,n);
        rep(i,0,n)z[i<<1]=a[i].x/n+0.5,z[i<<1|1]=a[i].y/n+0.5;
    }
    int n,m,p;
    ll a[N],b[N],c[N];
    int main(){
        n=it+1;m=it+1;
        rep(i,0,n) a[i]=it;
        rep(i,0,m) b[i]=it;
        for(n+=m-1,p=2;p<n;p<<=1);
        conv(p,a,b,c);
        rep(i,0,n)it.out(c[i]);
        return 0;
    }
    

    3.拆系数FFT

    要计算任意模数的卷积,我们一般考虑NTT+中国剩余定理CRT。NTT中需要模数是质数且表示为(p=ccdot 2^k+1)(2^k)要不小于n。

    考虑直接算出卷积不取模,那么每个数不会超过(M^2n)。假设模数(M)(10^9)级别,n是(10^5)级别,那么结果都是(10^{23})级别,我们可以找三个都是(10^9)级别满足NTT要求的模数,利用中国剩余定理就能得到在(10^{27})级别的模数意义下的结果,再对(M)取模即可。

    但是这样常数就要乘3了。效率太低。拆系数FFT就是替代NTT解决模任意数且非常高效的算法。

    如果利用FFT计算,浮点数会有误差,int128是一个方法,但是不是所有场合都能使用。所以需要拆系数。

    (M_0=lceil sqrt M ceil),设

    [a_i=k[a_i]M_0+b[a_i]\ b_i=k[b_i]M_0+b[b_i] ]

    其中(k[a_i],b[a_i]< M_0)

    假设(K_a(x))是以(k[a_i])为系数的多项式,(B_a(x))是以(b[a_i])为系数的多项式,(K_b(x),B_b(x))同理,则:

    [A(x)=K_a(x)M_0+B_a(x)\ B(x)=K_b(x)M_0+B_b(x)\ A(x)B(x)=K_a(x)K_b(x)M_0^2+(K_a(x)B_b(x)+K_b(x)B_a(x))M_0+B_a(x)B_b(x) ]

    和上面「更快的卷积」一样分析,两两合并可以将7次DFT及IDFT计算优化为4次:

    (M_0)可以取一个超过(sqrt M)的2的幂次,比较方便计算。

    [P(x)=K_a(x)+iB_a(x)\ Q(x)=K_b(x)+iB_b(x) ]

    可知

    [DFT(K_a[k])=frac {F_p[k]+!(F_p[(n-k)\%n])} 2\ DFT(B_a[k])=-ifrac {F_p[k]-!(F_p[(n-k)\%n])} 2\ DFT(K_b[k])=frac {F_q[k]-!(F_q[(n-k)\%n])} 2\ DFT(B_b[k])=-ifrac {F_q[k]-!(F_q[(n-k)\%n])} 2\ ]

    于是只要计算出P(x)的DFT:(F_p(x))和Q(x)的DFT:(F_q(x)),就能求出(K_a(x),B_a(x),K_b(x),B_b(x))的DFT。

    接下来IDFT的两两合并,以(K_a(x)K_b(x))(K_a(x)B_b(x))为例,令

    [dfta[k]=DFT(K_a[k])cdot DFT(K_b[k])\ dftb[k]=DFT(K_a[k])cdot DFT(B_b[k]) ]

    我们需要对(dfta(x))(dftb(x))进行IDFT。注意到这里IDFT的结果一定是实数,那么令

    [p[k]=dfta[k]+icdot dftb[k] ]

    那么 (IDFT(p)) 的实部除以n就是(K_a(x)K_b(x)),虚部除以n就是(K_a(x)B_b(x))

    由于(、k[x]、b[x])都是不超过(2^{15})的数,于是就不容易被卡精度了。计算出来的结果再取模M就是答案了。

    //933ms
    #include <bits/stdc++.h>
    #define rep(i,l,r) for(int i=l,ed=r;i<ed;++i)
    typedef long long ll;
    const double PI = acos(-1);
    const int N = 1<<20;
    const ll mod = 1e9+7;
    const int BUF_SIZE=33554431;
    using namespace std;
    
    struct buf{
        char a[BUF_SIZE],b[BUF_SIZE],*s,*t;
        buf():s(a),t(b){a[fread(a,1,sizeof a,stdin)]=0;}
        ~buf(){fwrite(b,1,t-b,stdout);}
        operator int(){
            int x=0;
            while(*s<48)++s;
            while(*s>32)
                x=x*10+*s++-48;
            return x;
        }
        void out(int x){
            static char c[12];
            char*i=c;
            if(!x)*t++=48;
            else{
                while(x){
                    int y=x/10;
                    *i++=x-y*10+48,x=y;
                }
                while(i!=c)*t++=*--i;
            }
            *t++=10;
        }
    }it;
    struct cp{
        double x,y;
        cp(double _x=0,double _y=0):x(_x),y(_y){}
        cp operator +(const cp&amp; b)const{return cp(x+b.x,y+b.y);}
        cp operator -(const cp&amp; b)const{return cp(x-b.x,y-b.y);}
        cp operator *(const cp&amp; b)const{return cp(x*b.x-y*b.y,x*b.y+y*b.x);}
        cp operator !()const{return cp(x,-y);}
    }w[N];
    void fft(cp p[],int n){
        for(int i=0,j=0;i<n;++i){
            if(i>j)swap(p[i],p[j]);
            for(int l=n>>1;(j^=l)<l;l>>=1);
        }
        for(int i=2;i<=n;i<<=1)
        for(int j=0,m=i>>1;j<n;j+=i)
            rep(k,0,m){
                cp b=w[n/i*k]*p[j+m+k];
                p[j+m+k]=p[j+k]-b;
                p[j+k]=p[j+k]+b;
            }
    }
    void conv(int n,ll *x,ll *y,ll *z){
        static cp p[N],q[N],a[N],b[N],c[N],d[N];
        static cp r(0.5,0),h(0,-0.5),o(0,1);
        rep(i,0,n){
            w[i]=cp(cos(2*PI*i/n),sin(2*PI*i/n));
            x[i]=(x[i]+mod)%mod,y[i]=(y[i]+mod)%mod;
            p[i]=cp(x[i]>>15,x[i]&amp;32767),q[i]=cp(y[i]>>15,y[i]&amp;32767);
        }
        fft(p,n);fft(q,n);
        rep(i,0,n){
            int j=i?(n-i):0;
            static cp ka,ba,kb,bb;
            ka=(p[i]+!p[j])*r;
            ba=(p[i]-!p[j])*h;
            kb=(q[i]+!q[j])*r;
            bb=(q[i]-!q[j])*h;
            a[j]=ka*kb;b[j]=ka*bb;
            c[j]=kb*ba;d[j]=ba*bb;
        }
        rep(i,0,n){
            p[i]=a[i]+b[i]*o;
            q[i]=c[i]+d[i]*o;
        }
        fft(p,n);fft(q,n);
        rep(i,0,n){
            ll a,b,c,d;
            a=(ll)(p[i].x/n+0.5)%mod;
            b=(ll)(p[i].y/n+0.5)%mod;
            c=(ll)(q[i].x/n+0.5)%mod;
            d=(ll)(q[i].y/n+0.5)%mod;
            z[i]=((a<<30)+((b+c)<<15)+d)%mod;
        }
    }
    int n,m,p;
    ll a[N],b[N],c[N];
    int main(){
        n=it+1;m=it+1;
        rep(i,0,n) a[i]=it;
        rep(i,0,m) b[i]=it;
        for(n+=m-1,p=1;p<n;p<<=1);
        conv(p,a,b,c);
        rep(i,0,n)it.out((c[i]+mod)%mod);
        return 0;
    }
    

    题目:

    待补充

  • 相关阅读:
    linux service 例子
    YII2自动初始化脚本
    ubuntu 如何在命令行打开当前目录
    mysql 储存过程
    Mysql 随笔记录
    Lack of free swap space on Zabbix server
    意外发现PHP另一个显示转换类型 binary
    常用的排序代码
    线程的实现方式之内核支持线程和用户级线程
    寻找二叉树中的最低公共祖先结点----LCA(Lowest Common Ancestor )问题(递归)
  • 原文地址:https://www.cnblogs.com/flipped/p/9669209.html
Copyright © 2011-2022 走看看