zoukankan      html  css  js  c++  java
  • FFT/FWT

    最近舟游疯狂出货,心情很好~

    FFT

    FWT


    快速傅里叶变换(FFT)

    具体的推导见这篇:胡小兔 - 小学生都能看懂的FFT!!! (写的很好,不过本小学生第一次没看懂0.0)

    总结下关键内容

    ~ Part 0 ~ 点值表示

    对于一$n$项多项式$A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}$

    我们代入$n$个不同的数$x_i$,得到$n$个值$y_i=A(x_i)$

    则称这$n$个有序数对$(x_i,y_i)$为多项式$A(x)$的点值表示(可以认为是$xOy$平面上的点)

    可以证明(略),通过这个$n$个点值可以还原出原多项式

    ~ Part 1 ~ 多项式表示转点值表示:

    将原多项式$A(x)$的次数补成$2$的幂次,然后进行递归拆分

    设当前一层的多项式为$B(x)$,项数为$n$,即:$B(x)=b_0+b_{1}x+b_{2}x^{2}+...+b_{n-1}x^{n-1}$

    先将此多项式的按照奇偶项系数拆成两个式子

    [B_0(x)=b_0+b_2x+b_4x^2+...+b_{n-2}x^{frac{n}{2}-1}]

    [B_1(x)=b_1+b_3x+b_5x^2+...+b_{n-1}x^{frac{n}{2}-1}]

    那么可以将$B(x)$重新表示:$B(x)=B_0(x^2)+xB_1(x^2)$

    我们希望将复平面上单位圆的点代入,得到$B(omega^{0}_{n})$,...,$B(omega^{n-1}_{n})$的值

    若$0leq i<frac{n}{2}$,则有【这里使用了复数运算律:因为$omega$在单位圆上,所以$omega^i_ncdotomega^j_n=omega^{i+j}_n$;因为$n$为$2$的倍数,所以$omega^{2i}_n=omega^i_{frac{n}{2}}$】

    egin{align*}B(omega^i_n)&=B_0((omega^i_n)^2)+omega^i_nB_1((omega^i_n)^2)\&=B_0(omega^{2i}_n)+omega^i_nB_1(omega^{2i}_n)\&=B_0(omega^i_{frac{n}{2}})+omega^i_nB_0(omega^i_{frac{n}{2}})end{align*}

    而对于单位圆上的另一半,则有【这里追加一个复数运算律:因为$n$为$2$的倍数,所以$omega^{i+frac{n}{2}}_n=-omega^i_n$】

    egin{align*}B(omega^{i+frac{n}{2}}_n)&=B_0((omega^{i+frac{n}{2}}_n)^2)+omega^{i+frac{n}{2}}_nB_1((omega^{i+frac{n}{2}}_n)^2)\&=B_0(omega^{2i+n}_n)-omega^i_nB_1(omega^{2i+n}_n)\&=B_0(omega^i_{frac{n}{2}})-omega^i_nB_0(omega^i_{frac{n}{2}})end{align*}

    所以我们将这一层计算$B(omega^0_n)$,...,$B(omega^{n-1}_{n-1})$的问题变为计算下一层$B_0(omega^0_{frac{n}{2}})$,...,$B_0(omega^{frac{n}{2}-1}_{frac{n}{2}})$,$B_1(omega^0_{frac{n}{2}})$,...,$B_1(omega^{frac{n}{2}-1}_{frac{n}{2}})$

    这样,总的复杂度是$O(N logN)$

    ~ Part 2 ~ 点值转多项式系数

    设原多项式为$A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}$,其中补成$2$的幂次项后项数为$n$

    则分别代入$omega^0_n$,$omega^1_n$,...,$omega^{n-1}_n$,得到$y_i=A(omega^i_n)$,$i=0,1,...,n-1$,即上面得到的$n$个点值

    令$B(x)=y_0+y_1x+y_2x^2+...+y_{n-1}x^{n-1}$

    再分别代入$omega^0_n$,$omega^{-1}_n$,...,$omega^{-(n-1)}_n$,运算(这里略去)得到$a_i=frac{B(omega^{-i}_n)}{n}$,$i=0,1,...,n-1$

    按这个思路写出的代码如下(实现了多项式系数转点值,再转回多项式系数)

    大佬用的数组指针确实方便

    #include <cstdio>
    #include <complex>
    #include <cmath>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    
    typedef complex<double> cp;
    const int N=100005;
    const double pi=acos(-1.0);
    
    inline cp omega(int x,int n,int rev)
    {
        if(rev)
            x=-x;
        return cp(cos(x*2*pi/(double)n),sin(x*2*pi/(double)n));
    }
    
    cp tmp[N];
    
    void fft(cp *a,int n,int rev)
    {
        if(n==1)
            return;
        
        for(int i=0;i<n;i++)
            tmp[i]=a[i];
        for(int i=0;i<n/2;i++)
        {
            a[i]=tmp[i*2];
            a[n/2+i]=tmp[i*2+1];
        }
        
        fft(a,n/2,rev);
        fft(a+n/2,n/2,rev);
        
        for(int i=0;i<n/2;i++)
        {
            cp x=omega(i,n,rev);
            tmp[i]=a[i]+x*a[n/2+i];
            tmp[n/2+i]=a[i]-x*a[n/2+i];
        }
        
        for(int i=0;i<n;i++)
            a[i]=tmp[i];
    }
    
    int n;
    cp a[N];
    
    int main()
    {
    //    freopen("input.txt","r",stdin);
        scanf("%d",&n);
        for(int i=0;i<n;i++)
        {
            int x;
            scanf("%d",&x);
            a[i]=cp((double)x,0);
        }
        
        int sz=1;
        while(sz<n)
            sz<<=1;
        n=sz;
        
        fft(a,n,0);
        fft(a,n,1);
        
        for(int i=0;i<n;i++)
            printf("%.2lf ",a[i].real()/(double)n);
        printf("
    ");
        return 0;
    }
    View Code

    ~ Part 3 ~  两个简单的优化

    非递归:

    将递归的第一部分,即交换$a[i]$以奇偶分组,直接完成

    这里的交换位置是两两一对的——若$n=8$,则$4$($100$)最终将与$1$($001$)互换

    即两个交换位置的二进制是前后翻转的

    去除$tmp[i]$数组:

    对于递归的第二部分for循环中的每个$i$,只会用到$tmp[i]$、$tmp[n/2+i]$、$a[i]$、$a[n/2+i]$

    所以对于当前$i$的改变是无后效性的,可以考虑使用临时变量而不是$tmp[i]$数组

    于是,可以写出没有递归与$tmp[i]$数组的$fft()$函数

    void fft(cp *a,int n,int rev)
    {
        for(int i=0;i<n;i++)
        {
            int nxt=0;
            for(int j=1;j<=n;j<<=1)
                if(i&j)
                    nxt+=n/2/j;
            
            if(nxt>i)
                swap(a[i],a[nxt]);
        }
        
        for(int i=2;i<=n;i<<=1)
            for(int j=0;j<n;j+=i)
                for(int k=0;k<i/2;k++)
                {
                    cp x=omega(k,i,rev);
                    cp b0=a[j+k],b1=a[i/2+j+k];
                    a[j+k]=b0+x*b1;
                    a[i/2+j+k]=b0-x*b1;
                }
    }
    View Code

    现在一般用的是手写complex、有预处理的版本

    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <algorithm>
    using namespace std;
    
    struct cp
    {
        double x,y;
        cp(double a=0.0,double b=0.0)
        {
            x=a,y=b;
        }
    };
    
    inline cp operator *(cp a,cp b)
    {
        return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
    }
    inline cp operator +(cp a,cp b)
    {
        return cp(a.x+b.x,a.y+b.y);
    }
    inline cp operator -(cp a,cp b)
    {
        return cp(a.x-b.x,a.y-b.y);
    }
    
    const int N=300005;
    const double pi=acos(-1.0);
    const double eps=0.00001;
    const int op[2]={1,-1};
    
    int to[N<<2];
    cp omega[N<<2][2];
    
    void Init(int n)
    {
        for(int i=0;i<n;i++)
        {
            int nxt=0;
            for(int j=1;j<=n;j<<=1)
                if(i&j)
                    nxt+=n/2/j;
            to[i]=nxt;
        }
        for(int i=0;i<n/2;i++)
            for(int j=0;j<2;j++)
            {
                double ang=op[j]*i*2*pi/(double)n;
                omega[i][j]=cp(cos(ang),sin(ang));
            }
    }
    
    void fft(cp *a,int n,int rev)
    {
        for(int i=0;i<n;i++)
            if(to[i]>i)
                swap(a[i],a[to[i]]);
        
        for(int i=1;i<=n;i<<=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    int p=j+k;
                    cp b0=a[p],t=omega[n/i*k][rev]*a[m+p];
                    a[p]=b0+t;
                    a[m+p]=b0-t;
                }
        }
    }
    View Code

    用FFT计算卷积

    一般的卷积有这种形式:$c_k=sumlimits_{i+j=k}^{}a_icdot b_j$

    而这恰好跟多项式乘法的形式是完全一致的

    所以可以用$O(nlogn)$的时间,计算出两个$n$维向量的卷积

    模板题:CF Gym101002 E ($K-Inversions$)

    将$A$、$B$分别用$1$表示一次,拆出两个数组

    这里由于是$j-i=k$,所以可以将$B$对应的数组反向来表示负次幂的整体平移

    然后两数组转为点值后乘起来,再转回去,就得到了结果;不过要注意并不是将整个卷积输出

    不预处理会TLE orz

    #include <cstdio>
    #include <complex>
    #include <cmath>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    
    typedef complex<double> cp;
    const int N=1000005;
    const double pi=acos(-1.0);
    const double eps=0.00001;
    const int op[2]={1,-1};
    
    cp omega[N<<2][2];
    
    void Init(int n)
    {
        for(int i=0;i<n/2;i++)
            for(int j=0;j<2;j++)
                omega[i][j]=cp(cos(op[j]*i*2*pi/(double)n),sin(op[j]*i*2*pi/(double)n));
    }
    
    void fft(cp *a,int n,int rev)
    {
        for(int i=0;i<n;i++)
        {
            int nxt=0;
            for(int j=1;j<=n;j<<=1)
                if(i&j)
                    nxt+=n/2/j;
            
            if(nxt>i)
                swap(a[i],a[nxt]);
        }
        
        for(int i=1;(1<<i)<=n;i++)
        {
            int m=1<<i;
            for(int j=0;j<n;j+=m)
                for(int k=0;k<m/2;k++)
                {
                    cp x=omega[n/m*k][rev];
                    cp b0=a[j+k],b1=a[m/2+j+k];
                    a[j+k]=b0+x*b1;
                    a[m/2+j+k]=b0-x*b1;
                }
        }
    }
    
    int n;
    char s[N];
    cp a[N<<2],b[N<<2],c[N<<2];
    
    int main()
    {
        scanf("%s",s);
        n=strlen(s);
        int sz=1;
        while(sz<n)
            sz<<=1;
        
        for(int i=0;i<n;i++)
            if(s[i]=='B')
                a[sz-i-1]=cp(1.0,0.0);
            else
                b[i]=cp(1.0,0.0);
        sz<<=1;
        
        Init(sz);
        fft(a,sz,0);
        fft(b,sz,0);
        
        for(int i=0;i<sz;i++)
            c[i]=a[i]*b[i];
        
        fft(c,sz,1);
        
        for(int i=sz/2;i<sz/2+n-1;i++)
            printf("%d
    ",int(c[i].real()/(double)sz+eps));
        return 0;
    }
    View Code

    一道不那么裸的题:CF 528D ($Fuzzy$ $Search$)

    这题的tutorial写的很好,就不具体写了

    大概思路是,把$ATGC$各看成一个subtask

    由于$k$的存在可以对字符串$S$进行左右平移,然后将$T$的倒序与平移后的$S$卷积,得出能匹配的数量

    #include <cstdio>
    #include <complex>
    #include <cmath>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    
    typedef complex<double> cp;
    const int N=280005;
    const double pi=acos(-1.0);
    const double eps=0.00001;
    const int op[2]={1,-1};
    
    cp omega[N<<1][2];
    
    void Init(int n)
    {
        for(int i=0;i<n/2;i++)
            for(int j=0;j<2;j++)
                omega[i][j]=cp(cos(op[j]*i*2*pi/(double)n),sin(op[j]*i*2*pi/(double)n));
    }
    
    void fft(cp *a,int n,int rev)
    {
        for(int i=0;i<n;i++)
        {
            int nxt=0;
            for(int j=1;j<=n;j<<=1)
                if(i&j)
                    nxt+=n/2/j;
            
            if(nxt>i)
                swap(a[i],a[nxt]);
        }
        
        for(int i=1;(1<<i)<=n;i++)
        {
            int m=1<<i;
            for(int j=0;j<n;j+=m)
                for(int k=0;k<m/2;k++)
                {
                    cp x=omega[n/m*k][rev];
                    cp b0=a[j+k],b1=a[m/2+j+k];
                    a[j+k]=b0+x*b1;
                    a[m/2+j+k]=b0-x*b1;
                }
        }
    }
    
    int n,m,k;
    char S[N],T[N];
    
    void Push(cp *a)
    {
        int lv=0;
        for(int i=0;i<n;i++,lv--)
        {
            if(a[i].real()>0.0)
                lv=k+1;
            if(lv>0)
                a[i]=cp(1.0,0.0);
        }
        
        lv=0;
        for(int i=n-1;i>=0;i--,lv--)
        {
            if(a[i].real()>0.0)
                lv=k+1;
            if(lv>0)
                a[i]=cp(1.0,0.0);
        }
    }
    
    int to[130];
    cp a[4][N<<1],b[4][N<<1];
    int cnt[4];
    
    int main()
    {
        scanf("%d%d%d",&n,&m,&k);
        scanf("%s",S);
        scanf("%s",T);
        int sz=1;
        while(sz<n)
            sz<<=1;
        
        to['A']=0,to['T']=1,to['G']=2,to['C']=3;
        for(int i=0;i<n;i++)
            a[to[S[i]]][i]=cp(1.0,0.0);
        for(int i=0;i<m;i++)
        {
            b[to[T[i]]][sz-i-1]=cp(1.0,0.0);
            cnt[to[T[i]]]++;
        }
        sz<<=1;
        
        for(int i=0;i<4;i++)
            Push(a[i]);
        
        Init(sz);
        for(int i=0;i<4;i++)
        {
            fft(a[i],sz,0);
            fft(b[i],sz,0);
        }
        
        for(int i=0;i<4;i++)
            for(int j=0;j<sz;j++)
                a[i][j]=a[i][j]*b[i][j];
        
        for(int i=0;i<4;i++)
            fft(a[i],sz,1);
        
        int ans=0;
        for(int i=n/2;i<sz/2+n-1;i++)
        {
            bool flag=true;
            for(int j=0;j<4;j++)
            {
                if(int(a[j][i].real()/(double)sz+eps)!=cnt[j])
                    flag=false;
            }
            
            if(flag)
                ans++;
        }
        printf("%d
    ",ans);
        return 0;
    }
    View Code

    还有一道类似的题:Luogu P4173 (残缺的字符串) 

    虽然题目差不多,但做法还是有些差距的

    学习了:Ebola - 题解 P4173 【残缺的字符串】 思路写的非常完善

    方法大概是,构造函数→暴力展开→(调整遍历顺序)→卷积

    有长为$n$的原串$A$与长为$m$的匹配串$B$,则匹配函数$f(x)=sum_{i=0}^{m-1}(A[x+i-m+1]-B[i])^2cdot A[x+i-m+1]B[i]$表示匹配结尾为$x$的匹配值($0$则完全匹配)

    将$B$倒序为$R$,则有$R[-i+m-1]=B[i]$,那么可以改写匹配函数为$f(x)=sum_{i=0}^{m-1}(A[x+i-m+1]-R[-i+m-1])^2cdot A[x+i-m+1]R[-i+m-1]$

    展开后即可发现明显的卷积:$f(x)=sum A[x+i-m+1]^3R[-i+m-1]-2sum A[x+i-m+1]^2R[-i+m-1]^2+sum A[x+i-m+1]R[-i+m-1]^2$

    常数是真的卡不过去了ToT 差不多是刚好超时的样子

    #include <cstdio>
    #include <complex>
    #include <cmath>
    #include <cstring>
    #include <algorithm>
    #include <vector>
    using namespace std;
    
    struct cp
    {
        double x,y;
        cp(double a=0.0,double b=0.0)
        {
            x=a,y=b;
        }
    };
    
    cp operator *(cp a,cp b)
    {
        return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
    }
    cp operator +(cp a,cp b)
    {
        return cp(a.x+b.x,a.y+b.y);
    }
    cp operator -(cp a,cp b)
    {
        return cp(a.x-b.x,a.y-b.y);
    }
    
    const int N=300005;
    const double pi=acos(-1.0);
    const double eps=0.00001;
    const int op[2]={1,-1};
    const int mul[4]={0,1,-2,1};
    
    int to[N<<2];
    cp omega[N<<2][2];
    
    void Init(int n)
    {
        for(int i=0;i<n;i++)
        {
            int nxt=0;
            for(int j=1;j<=n;j<<=1)
                if(i&j)
                    nxt+=n/2/j;
            to[i]=nxt;
        }
        for(int i=0;i<n/2;i++)
            for(int j=0;j<2;j++)
                omega[i][j]=cp(cos(op[j]*i*2*pi/(double)n),sin(op[j]*i*2*pi/(double)n));
    }
    
    void fft(cp *a,int n,int rev)
    {
        for(int i=0;i<n;i++)
            if(to[i]>i)
                swap(a[i],a[to[i]]);
        
        for(int i=1;(1<<i)<=n;i++)
        {
            int m=1<<i;
            for(int j=0;j<n;j+=m)
                for(int k=0;k<m/2;k++)
                {
                    cp x=omega[n/m*k][rev];
                    cp b0=a[j+k],b1=a[m/2+j+k];
                    a[j+k]=b0+x*b1;
                    a[m/2+j+k]=b0-x*b1;
                }
        }
    }
    
    int n,m;
    char A[N],B[N];
    
    cp a[N<<2],b[N<<2];
    cp f[N<<2];
    
    vector<int> ans;
    
    int main()
    {
        scanf("%d%d",&m,&n);
        scanf("%s",B);
        scanf("%s",A);
        
        int sz=1;
        while(sz<n)
            sz<<=1;
        sz<<=1;
        Init(sz);
        
        for(int i=1;i<=3;i++)
        {
            for(int j=0;j<n;j++)
            {
                int val=(A[j]=='*'?0:A[j]-'a'+1);
                a[j]=cp(pow(val,4-i),0);
            }
            for(int j=n;j<sz;j++)
                a[j]=cp(0,0);
            for(int j=0;j<m;j++)
            {
                int val=(B[-j+m-1]=='*'?0:B[-j+m-1]-'a'+1);
                b[j]=cp(pow(val,i),0);
            }
            for(int j=m;j<sz;j++)
                b[j]=cp(0,0);
            
            fft(a,sz,0);
            fft(b,sz,0);
            
            for(int j=0;j<sz;j++)
                f[j]=f[j]+cp(mul[i],0)*a[j]*b[j];
        }
        
        fft(f,sz,1);
        
        for(int i=m-1;i<n;i++)
            if(int(f[i].x/(double)sz+eps)==0)
                ans.push_back(i-m+2);
        
        printf("%d
    ",ans.size());
        for(int i=0;i<ans.size();i++)
            printf("%d ",ans[i]);
        
        return 0;
    }
    View Code

    神仙题:CF 623E ($Transforming$ $Sequence$)

    首先要想到tutorial的并归思想,然后要对式子变形成可卷积的形式(即乘积中的一部分只有$i$,另一部分只有$n-i$)

    FFT还要能取模...(貌似就是NTT?)

    暂时就不做了orz


    快速沃尔什变换(FWT)

    参考了这篇:yyb - FWT快速沃尔什变换学习笔记 不过思路还是有些区别

    FWT主要解决的问题是位运算卷积,比如:$C_k=sum_{i ext{|}j=k}A_icdot B_j$,$C_k=sum_{i ext{&}j=k}A_icdot B_j$,$C_k=sum_{i ext{^}j=k}A_icdot B_j$

    这个跟FFT还是有一些区别的,毕竟FFT可以用多项式乘法形象地理解

    ~ Part 0 ~ 从FFT到FWT

    现在重新考虑一下FFT做了什么事情:

    相当于构造了对一维向量$A$的一种变换方式$A ightarrow FFT(A)$

    使得对于两个一维向量$A$、$B$以及它们的卷积$C$,有$FFT(C)=FFT(A)cdot FFT(B)$(暂且先不考虑逆变换)

    这里,我们是用点值来考虑的——因为$FFT(A)[i]=A(omega_n^i)$,所以上式相当于$forall i,C(omega_n^i)=A(omega_n^i)cdot B(omega_n^i)$

    正确性是显然的

    然后我们类似的考虑FWT要做的事情(以or卷积为例):

    构造对一维向量$A$的一种变换方式$A ightarrow FWT(A)$

    使得对于两个一维向量$A$、$B$以及它们的or卷积$C$,有$FWT(C)=FWT(A)cdot FWT(B)$

    类似于FFT,我们并不需要知道如何推导出的变换(即为什么代入$omega_n^i$),只需要知道变换方式的正确性

    ~ Part 1 ~ or卷积的FWT

    现给出or卷积的FWT变换:

    记向量$A$的项数为$n$,且$n$为$2$的幂次,则$A=(a_0,a_1,...,a_{n-1})$

    拆出$A$的前后一半,分别记为$A_0$与$A_1$,则$A_0=(a_0,a_1,...,a_{frac{n}{2}-1})$,$A_1=(a_{frac{n}{2}},a_{frac{n}{2} +1},...,a_{n-1})$

    [FWT(A)=egin{cases}(FWT(A_0),FWT(A_0+A_1)) &left| A ight| >1\A &left| A ight| =1end{cases}]

    ~ Part 2 ~ 正确性的证明

    通过观察,我们可以很明显的发现一个性质

    对于$0 ext{~}frac{n}{2}-1$,它的二进制第一位是$0$;而$frac{n}{2} ext{~}n-1$的二进制第一位是$1$

    那么这个变换的步骤,就是将位置编号为$i$的元素不停地向后加——其中的“向后”是指,将$i$二进制中的$0$依次变为$1$

    比如,取$n=8$,则变换的步骤为

    可以发现,$FWT(A)[i]=sum_{i ext{|}j=i}a_j$

    那么,有[egin{align*}& FWT(A)[i]cdot FWT(B)[i]\&=sum_{i ext{|}j=i}a_jcdot sum_{i ext{|}j=i}b_j\&=sum_{i ext{|}j=i}(a_jcdot sum_{i ext{|}k=i}b_k)\&=sum_{i ext{|}j=i,i ext{|}k=i}a_jcdot b_kend{align*}]

    而根据or卷积的定义,$C_i=sum_{j ext{|}k=i}a_jcdot b_k$

    则有[egin{align*}FWT(C)[i]&=sum_{i ext{|}j=i}C_j\&=sum_{i ext{|}j=i}sum_{k ext{|}l=j}a_kcdot b_l\&=sum_{i ext{|}k=i}a_kcdot sum_{i ext{|}l=i}b_lquad ext{此时j=k|l}\&=sum_{i ext{|}j=i,i ext{|}k=i}a_icdot b_jend{align*}]

    由于$forall i,FWT(C)[i]=FWT(A)[i]cdot FWT(B)[i]$,所以$FWT(C)=FWT(A)cdot FWT(B)$,证毕

    所以可以放心食用or卷积了~

    ~ Part 3 ~ FWT的结论(们)

    现在给出三种位运算卷积的FWT

    其余两种的证明方法应该是类似的(好像xor有点差距0.0 不管了)

    对于or卷积:$FWT(A)=egin{cases}(FWT(A_0),FWT(A_0+A_1)) &left| A ight| >1\A &left| A ight| =1end{cases}$

    对于and卷积:$FWT(A)=egin{cases}(FWT(A_0+A_1),FWT(A_1)) &left| A ight| >1\A &left| A ight| =1end{cases}$

    对于xor卷积:$FWT(A)=egin{cases}(FWT(A_0+A_1),FWT(A_0-A_1)) &left| A ight| >1\A &left| A ight| =1end{cases}$

    也可以方便的找到逆变换IFWT(就是和差问题)

    设向量$B=FWT(A)$,同样对半分成$B_0$,$B_1$

    对于or卷积:$IFWT(B)=egin{cases}(IFWT(B_0),IFWT(B_1-B_0)) &left| B ight| >1\B &left| B ight| =1end{cases}$

    对于and卷积:$IFWT(B)=egin{cases}(IFWT(B_0-B_1),IFWT(B_1)) &left| B ight| >1\B &left| B ight| =1end{cases}$

    对于xor卷积:$IFWT(B)=egin{cases}(frac{IFWT(B_0+B_1)}{2},frac{IFWT(B_0-B_1)}{2}) &left| B ight| >1\B &left| B ight| =1end{cases}$

    还有两个比较常用的运算律,成立的原因在于FWT只是简单的加减变换

    $FWT(Acdot A)=FWT(A)cdot FWT(A)$

    $FWT(A+A)=FWT(A)+FWT(A)$

    模板题:Luogu P4717 (【模板】快速沃尔什变换)

    #include <algorithm>
    #include <cstring>
    #include <cstdio>
    using namespace std;
    
    typedef long long ll;
    const int N=1<<17;
    const int MOD=998244353;
    
    inline ll mod(ll x)
    {
        if(x>=MOD)
            return x-MOD;
        if(x<0)
            return x+MOD;
        return x;
    }
    
    ll rev=1;
    
    void Init()
    {
        ll n=MOD-2,x=2;
        while(n)
        {
            if(n&1)
                rev=(rev*x)%MOD;
            x=(x*x)%MOD;
            n>>=1;
        }
    }
    
    void FWTor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+m+k]=mod(a[j+k]+a[j+m+k]);
        }
    }
    
    void IFWTor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+m+k]=mod(a[j+m+k]-a[j+k]);
        }
    }
    
    void FWTand(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+k]=mod(a[j+k]+a[j+m+k]);
        }
    }
    
    void IFWTand(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+k]=mod(a[j+k]-a[j+m+k]);
        }
    }
    
    void FWTxor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    ll x=a[j+k],y=a[j+m+k];
                    a[j+k]=mod(x+y);
                    a[j+m+k]=mod(x-y);
                }
        }
    }
    
    void IFWTxor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    ll x=a[j+k],y=a[j+m+k];
                    a[j+k]=(x+y)*rev%MOD;
                    a[j+m+k]=(x-y)*rev%MOD;
                }
        }
    }
    
    ll A[N],B[N];
    ll a[N],b[N];
    
    int main()
    {
        int n;
        scanf("%d",&n);
        n=1<<n;
        for(int i=0;i<n;i++)
            scanf("%d",&A[i]);
        for(int i=0;i<n;i++)
            scanf("%d",&B[i]);
        
        Init();
        
        for(int i=0;i<n;i++)
            a[i]=A[i],b[i]=B[i];
        FWTor(a,n);
        FWTor(b,n);
        for(int i=0;i<n;i++)
            a[i]=(a[i]*b[i])%MOD;
        IFWTor(a,n);
        for(int i=0;i<n;i++)
            printf("%lld ",mod(a[i]));
        printf("
    ");
        
        
        for(int i=0;i<n;i++)
            a[i]=A[i],b[i]=B[i];
        FWTand(a,n);
        FWTand(b,n);
        for(int i=0;i<n;i++)
            a[i]=(a[i]*b[i])%MOD;
        IFWTand(a,n);
        for(int i=0;i<n;i++)
            printf("%lld ",mod(a[i]));
        printf("
    ");
        
        for(int i=0;i<n;i++)
            a[i]=A[i],b[i]=B[i];
        FWTxor(a,n);
        FWTxor(b,n);
        for(int i=0;i<n;i++)
            a[i]=(a[i]*b[i])%MOD;
        IFWTxor(a,n);
        for(int i=0;i<n;i++)
            printf("%lld ",mod(a[i]));
        printf("
    ");
        
        return 0;
    }
    View Code

    rls推荐的一题:BZOJ 4589 ($Hard$ $Nim$)

    题目就是求有多少$a[i]$的选择方案,使得$a[1] ext{^}a[2] ext{^}... ext{^}a[n]=0$

    受到上面$Transforming$ $Sequence$的启发(因为题目中$n$很大,$m$较小),我们可以考虑并归

    一个长度为$n$的区间可以通过多个区间合并而来,比如可以用类似快速幂的方法,每次合并进一个长度为$2$的幂次的区间

    然后“合并”操作恰恰就是xor卷积:$C_k=sum_{i ext{^}j=k}a_icdot b_j$,其中$a_i$表示整个区间异或值为$i$的方案数

    $a_i$的初始值为:$a_i=egin{cases}1& ext{i为质数}\0& ext{i为合数}end{cases}$(题目中要求每堆石子个数是质数)

    又由$FWT(Acdot A)=FWT(A)cdot FWT(A)$,我们只需对$A$在循环外做一次FWTxor变换,在合并时只需要将$FWT(A)$点乘自己

    总的复杂度为$O(Tcdot k(logn+logk))$

    #include <algorithm>
    #include <cstring>
    #include <cstdio>
    using namespace std;
    
    typedef long long ll;
    const int N=1<<19;
    const int MOD=1000000007;
    
    inline ll mod(ll x)
    {
        if(x>=MOD)
            return x-MOD;
        if(x<0)
            return x+MOD;
        return x;
    }
    
    ll rev=1;
    
    void Init()
    {
        ll n=MOD-2,x=2;
        while(n)
        {
            if(n&1)
                rev=(rev*x)%MOD;
            x=(x*x)%MOD;
            n>>=1;
        }
    }
    
    void FWTor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+m+k]=mod(a[j+k]+a[j+m+k]);
        }
    }
    
    void IFWTor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+m+k]=mod(a[j+m+k]-a[j+k]);
        }
    }
    
    void FWTand(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+k]=mod(a[j+k]+a[j+m+k]);
        }
    }
    
    void IFWTand(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+k]=mod(a[j+k]-a[j+m+k]);
        }
    }
    
    void FWTxor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    ll x=a[j+k],y=a[j+m+k];
                    a[j+k]=mod(x+y);
                    a[j+m+k]=mod(x-y);
                }
        }
    }
    
    void IFWTxor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    ll x=a[j+k],y=a[j+m+k];
                    a[j+k]=(x+y)*rev%MOD;
                    a[j+m+k]=(x-y)*rev%MOD;
                }
        }
    }
    
    int prime[N];
    
    ll a[N],res[N];
    
    int main()
    {
        Init();
        for(int i=2;i<N;i++)
            prime[i]=1;
        for(int i=2;i<N;i++)
            if(prime[i])
                for(int j=i+i;j<N;j+=i)
                    prime[j]=0;
        
        int n,m;
        while(scanf("%d%d",&n,&m)!=EOF)
        {
            int sz=1;
            while(sz<=m)
                sz<<=1;
            
            for(int i=0;i<sz;i++)
            {
                a[i]=(i<=m?prime[i]:0);
                res[i]=1;
            }
            
            FWTxor(a,sz);
            
            while(n)
            {
                if(n&1)
                    for(int i=0;i<sz;i++)
                        res[i]=res[i]*a[i]%MOD;
                
                for(int i=0;i<sz;i++)
                    a[i]=a[i]*a[i]%MOD;
                n>>=1;
            }
            
            IFWTxor(res,sz);
            
            printf("%d
    ",res[0]);
        }
        
        
        return 0;
    }
    View Code

    依旧是rls推荐的一题:HDU 5909 ($Tree$ $Cutting$)

    先看一眼,$n$挺小,嗯...(意味深长)

    这道题是对一个无根树的子树求xor和,首先不妨定一个根

    定下根后,整棵树的层次就确定了;于是,想要统计所有无根子树,就相当于从下向上、统计每个节点的子树,并把结果加起来

    正确性很好证明:对于以$x$为根节点的子树,包含$x$的所有无根子树在$x$处被统计,不包含$x$的在$x$的后代处被统计

    那么现在要解决的问题变成,已知儿子的统计信息,如何得出当前节点$cur$的结果

    首先,当前节点是必选的

    然后,我们可以在 以$cur$的儿子们为根 的子树中,任意选择节点

    对于每一个儿子,选择的结果就是已知的统计信息;而想要将所有儿子的选择合并,恰好就是卷积

    比如,在儿子$1$的子树中xor值为$i$的选择方案数$a_i$,与儿子$2$的子树中xor值为$j$的选择方案数$b_j$,对当前节点xor值为$v_{cur} ext{^} i ext{^} j$($v_i$为题目中节点的权值)的贡献数为$a_icdot b_j$

    大方向有了,还有三个细节问题

    1. 我们对于每个节点应当存储的是FWTxor的值,否则对于一个儿子的合并不能做到$O(n)$,而会多一个$logn$

    2. 儿子的子树中可以一个节点都不选,即xor值为$0$的方案在计算$cur$时应加$1$;对于FWT,相当于对全体加$1$

    3. 输出时,行末无空格,文末有回车

    #include <algorithm>
    #include <cstring>
    #include <cstdio>
    #include <vector>
    using namespace std;
    
    typedef long long ll;
    const int N=1<<10;
    const int MOD=1000000007;
    
    inline ll mod(ll x)
    {
        if(x>=MOD)
            return x-MOD;
        if(x<0)
            return x+MOD;
        return x;
    }
    
    ll rev=1;
    
    void Init()
    {
        ll n=MOD-2,x=2;
        while(n)
        {
            if(n&1)
                rev=(rev*x)%MOD;
            x=(x*x)%MOD;
            n>>=1;
        }
    }
    
    void FWTor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+m+k]=mod(a[j+k]+a[j+m+k]);
        }
    }
    
    void IFWTor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+m+k]=mod(a[j+m+k]-a[j+k]);
        }
    }
    
    void FWTand(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+k]=mod(a[j+k]+a[j+m+k]);
        }
    }
    
    void IFWTand(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                    a[j+k]=mod(a[j+k]-a[j+m+k]);
        }
    }
    
    void FWTxor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    ll x=a[j+k],y=a[j+m+k];
                    a[j+k]=mod(x+y);
                    a[j+m+k]=mod(x-y);
                }
        }
    }
    
    void IFWTxor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    ll x=a[j+k],y=a[j+m+k];
                    a[j+k]=(x+y)*rev%MOD;
                    a[j+m+k]=(x-y)*rev%MOD;
                }
        }
    }
    
    int n,m;
    int a[N];
    vector<int> v[N];
    
    inline bool cmp(int x,int y)
    {
        return a[x]<a[y];
    }
    
    ll res[N];
    ll val[N][N];
    
    inline void dfs(int x,int fa)
    {
        for(int i=0;i<m;i++)
            val[x][i]=0;
        val[x][a[x]]=1;
        FWTxor(val[x],m);
        
        for(int i=0;i<v[x].size();i++)
        {
            int to=v[x][i];
            if(to==fa)
                continue;
            
            dfs(to,x);
            
            for(int j=0;j<m;j++)
                val[x][j]=(val[x][j]*(val[to][j]+1))%MOD;
        }
    }
    
    ll ans[N];
    
    int main()
    {
        Init();
        
        int T;
        scanf("%d",&T);
        while(T--)
        {
            for(int i=0;i<N;i++)
                v[i].clear(),ans[i]=0;
            
            scanf("%d%d",&n,&m);
            for(int i=1;i<=n;i++)
                scanf("%d",&a[i]);    
            for(int i=1;i<n;i++)
            {
                int x,y;
                scanf("%d%d",&x,&y);
                v[x].push_back(y);
                v[y].push_back(x); 
            }
            
            dfs(1,0);
            
            for(int i=1;i<=n;i++)
            {
                IFWTxor(val[i],m);
                for(int j=0;j<m;j++)
                    ans[j]=mod(ans[j]+val[i][j]);
            }
            
            for(int i=0;i<m;i++)
            {
                printf("%lld",ans[i]);
                if(i!=m-1)
                    putchar(' ');
            }
            printf("
    ");
        }
        
        return 0;
    }
    View Code

    最后看题解才做出来的:CF Gym 101955I ($Distance$ $Between$ $Sweethearts$,ICPC沈阳 2018)

    如果没有$max{left| I_{boy}-I_{girl} ight|, left| A_{boy}-A_{girl} ight|,left| G_{boy}-G_{girl} ight|}$,那么这题直接把$6$个数组FWT后卷起来就行了

    但是问题在于,在作差后,$I_{boy}$与$I_{girl}$、$A_{boy}$与$A_{girl}$、$G_{boy}$与$G_{girl}$之间有联系;所以应当对于每一属性分别考虑

    以$I$属性为例,我们显然可以用$O(n)$扫一遍,得出当$left| I_{boy}-I_{girl} ight|=x$($x$为一指定数)时,$I_{boy} ext{^} I_{girl}=i$($i$从$0 ext{~} 2048$)的数量

    接着,我们发现对三个差取max是一个比较麻烦的条件,于是可以考虑从小到大枚举这个max

    假设当前三个差中的最大值为$val$($0leq val< 2048$)

    对于每个属性先扫一遍,预处理出当$left| I_{boy}-I_{girl} ight|=val$时,$I_{boy} ext{^} I_{girl}=i$($i$从$0 ext{~} 2048$)的数量,记为$cur[i]$数组;类似的也可以得到$A$、$G$的信息

    然后对每个属性的信息分别做前缀和,得到$left| I_{boy}-I_{girl} ight|leq val$时,$I_{boy} ext{^} I_{girl}=i$($i$从$0 ext{~} 2048$)的数量,记为$pre[i]$数组;$A$、$G$类似

    然后考虑容斥:

    要保证三个差中的最大值为$val$,有三种情况——$1$个差为$val$、$2$个差为$val$、$3$个差为$val$

    对于$1$个差为$val$的情况,我们可以这样统计:要求一个属性$left| boy-girl ight|=val$,另两个$left| boy-girl ight|leq val$,即相当于将$1$个$cur[i]$数组与$2$个$pre[i]$数组做xor卷积

    对于$2$个差为$val$的情况:要求两个属性$left| boy-girl ight|=val$,另一个$left| boy-girl ight|leq val$,即相当于将$2$个$cur[i]$数组与$1$个$pre[i]$数组做xor卷积

    对于$3$个差为$val$的情况:要求三个属性$left| boy-girl ight|=val$,即相当于将$3$个$cur[i]$数组做xor卷积

    于是最大值为$val$时,$I_{boy} ext{^} I_{girl} ext{^} A_{boy} ext{^} A_{girl} ext{^} G_{boy} ext{^} G_{girl}=i$的数量$sum_i$可以通过充斥得出,则对答案的贡献为$sum_icdot (val ext{^} i)$

    常数有点大= =

    #include <algorithm>
    #include <cstring>
    #include <cstdio>
    using namespace std;
    
    typedef long long ll;
    const int N=1<<11;
    
    void FWTxor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    ll x=a[j+k],y=a[j+m+k];
                    a[j+k]=x+y;
                    a[j+m+k]=x-y;
                }
        }
    }
    
    void IFWTxor(ll *a,int n)
    {
        for(int i=n;i>=2;i>>=1)
        {
            int m=i>>1;
            for(int j=0;j<n;j+=i)
                for(int k=0;k<m;k++)
                {
                    ll x=a[j+k],y=a[j+m+k];
                    a[j+k]=(x+y)/2;
                    a[j+m+k]=(x-y)/2;
                }
        }
    }
    
    int a[10];
    ll cur[3][N],pre[3][N],sum[N];
    
    int main()
    {
        int T;
        scanf("%d",&T);
        for(int it=1;it<=T;it++)
        {
            memset(pre,0,sizeof(pre));
            unsigned long long ans=0;
            for(int i=0;i<6;i++)
                scanf("%d",&a[i]);
            
            for(int i=0;i<N;i++)
            {
                memset(cur,0,sizeof(cur));
                memset(sum,0,sizeof(sum));
                
                for(int j=0;j<3;j++)
                {
                    for(int k=i;k<=a[j] && k-i<=a[3+j];k++)
                        cur[j][k^(k-i)]++;
                    if(i!=0)
                        for(int k=i;k<=a[3+j] && k-i<=a[j];k++)
                            cur[j][k^(k-i)]++;
                    
                    FWTxor(cur[j],N);
                    
                    for(int k=0;k<N;k++)
                        pre[j][k]+=cur[j][k];
                }
                
                for(int j=0;j<3;j++)
                    for(int k=0;k<N;k++)
                    {
                        ll tmp1=cur[j][k],tmp2=pre[j][k];
                        for(int l=0;l<3;l++)
                            if(j!=l)
                                tmp1*=pre[l][k],tmp2*=cur[l][k];
                        sum[k]+=tmp1-tmp2;
                    }
                
                for(int j=0;j<N;j++)
                {
                    for(int k=1;k<3;k++)
                        cur[0][j]*=cur[k][j];
                    sum[j]+=cur[0][j];
                }
                
                IFWTxor(sum,N);
                
                for(int j=0;j<N;j++)
                    ans+=sum[j]*(i^j);
            }
            
            printf("Case #%d: %llu
    ",it,ans);
        }
        
        return 0;
    }
    View Code

    数论变换(NTT)

    可以认为是能取模的FFT?

    最近有点不想看这个...效率低下 待填坑

    (待续)

  • 相关阅读:
    使用getattr() 分类: python基础学习 divide into python 2014-02-24 15:50 198人阅读 评论(0) 收藏
    使用locals()获得类,进行分发 分类: python 小练习 divide into python python基础学习 2014-02-21 14:51 217人阅读 评论(0) 收藏
    第1课第4.4节_Android硬件访问服务编写HAL代码
    第4.3节_Android硬件访问服务编写APP代码
    函数说明
    第1课第1节_编写第1个Android应用程序实现按钮和复选框
    vue生命周期
    程序代码中,怎么区分status和state?
    百度UEditor -- ZeroClipboard is not defined
    webstorm 设置ES6语法支持以及添加vuejs开发配置
  • 原文地址:https://www.cnblogs.com/LiuRunky/p/FFT_NTT_FWT.html
Copyright © 2011-2022 走看看