zoukankan      html  css  js  c++  java
  • 快速沃尔什变换(FWT)学习笔记 + 洛谷P4717 [模板]

    FWT求解的是一类问题:( a[i] = sumlimits_{jigoplus k=i}^{} b[j]*c[k] )

    其中,( igoplus ) 可以是 or,and,xor

    三种问题的解决思路都是对多项式 ( a ) 构造一个 ( a' ),令 ( a' = b' * c' );

    那么只需要把 ( b ) 变换成 ( b' ),( c ) 变换成 ( c' ),然后乘出 ( a' ),再逆变换得到 ( a );

    下面问题就变成如何快速(logn)求 ( b ) 到 ( b' ) 的变换,这个变换就是 FWT;

    始终要记住进行位运算的是位置(角标)而不是值;

    一、or

    构造 ( a'[i] = sumlimits_{j|i=i}^{} a[j] )

    1.正变换

    考虑把 ( a ) 分成前后两个部分 ( a0 ) 和 ( a1 ),先分别递归下去做好,得到 ( a0' ) 和 ( a1' );

    可以发现,( a0' ) 和 ( a1' ) 的位置(角标)数字上唯一不同就是最高位是0或1;

    但递归下去做的时候,( a0' ) 和 ( a1' ) 的位置数字相当与去掉了最高位(因为折半了);

    所以合并的时候,关键要考虑到最高位的0和1的不同:

    (1) 对于 ( a' ) 的一个位置 ( i ) ,如果它在前半部分,那么它可以直接继承 ( a0'[i] );

    而 ( a1'[i]) 由于实际上 ( i ) 还应该加上最高位的1,or 运算使它能贡献的位置最高位也是1,但 ( i ) 的最高位是0,所以不贡献给 ( a'[i] ) ;

    (2)对于后半部分的 ( i ) ,( a0'[i] ) 和 ( a1'[i] ) 都会对它产生贡献,因为两部分的位置数字都是 ( i ) 的子集;

    所以可以得到:( a' = left ( a0' , a0'+a1' ight ) )

    递归的底层,只有一个元素的时候,( a = a' ) ,于是我们可以递归做出正变换了;

    当然,仿照 FFT 的写法即可,并不需要真的写递归函数,而且也不用蝴蝶变换;

    2.逆变换

    同样先考虑两个部分 ( a'0 ) 和 ( a'1 ) ,表示 ( a' ) 的前后部分;

    已经做了 ( a' = left ( a0' , a0'+a1' ight ) )

    现在要从 ( a' ) 拆出 ( a0' ) 和 ( a1' )

    那么 ( a0' = a'0 )

    而且 ( a1' = a'1 - a'0 )

    知道了 ( a0' ) 和 ( a0' ) ,就可以继续递归求解 ( a0 ) 和 ( a1 ),二者合起来就可以得到 ( a )

    递归的底层,只有一个元素的时候,( a' = a ) ,于是我们可以递归作出逆变换了;

    void fwt1(int *a,int tp)//a'=(a0',a0'+a1')  //a=(a0',a1'-a0')
    {
      for(int mid=1;mid<lim;mid<<=1)
        for(int j=0,len=(mid<<1);j<lim;j+=len)
          for(int k=0;k<mid;k++)
          a[j+mid+k]=upt(a[j+mid+k]+tp*a[j+k]);
    }
    or

    二、and

    构造 ( a' = sumlimits_{j & i=i}^{} a[j] )

    1.正变换

    和 or 同理,考虑最高位01的不同,后面继承本身,而前面要加上后面的贡献;

    得到 ( a' = left ( a0'+a1' , a1' ight ) )

    2.逆变换

    同理,得到 

    ( a0' = a'0 - a'1 )

    ( a1' = a'1 )

    void fwt2(int *a,int tp)//a'=(a0'+a1',a1')  //a=(a0'-a1',a1')
    {
      for(int mid=1;mid<lim;mid<<=1)
        for(int j=0,len=(mid<<1);j<lim;j+=len)
          for(int k=0;k<mid;k++)
          a[j+k]=upt(a[j+k]+tp*a[j+mid+k]);
    }
    and

    三、xor

    设 ( d(i,j) ) 表示 ( i&j ) 二进制表示中1的个数;

    构造 ( a' = sumlimits_{d(i,j)\%2==0}^{} a[j] - sumlimits_{d(i,j)\%2==1}^{} a[j] )

    1.正变换

    让我们三步走:

    (1) ( a' = left ( a0' + a1' , a0' - a1' ight ) )

    首先明确,( a' ) 是 ( d(i&j) ) 为偶数的 ( a[j] ) 求和,减去 ( d(i&j) ) 为奇数的 ( a[j] ) 求和;

    <1> 对于整体的一个位置 ( i ),它在前半部分

    对于前半部分(折半)的相同位置 ( i' ),在前半部分的 ( j ) 中,( d(i'&j) ) 的奇偶性和 ( d(i&j) ) 一样,所以继承答案;

    对于后半部分(折半)的相同位置 ( i' ),在后半部分的 ( j ) 中,计算 ( i'&j ) 时是没有考虑最高位的,所以它们的最高位上都是0,

    而因为 ( i ) 的最高位是0,( i&j ) 的最高位同样是0,所以正好符合,答案可以加上;

    也就是,( a' = left ( a0' + a1' , ... ight ) )

    <2> 对于整体的一个位置 ( i ),它在后半部分

    对于前半部分(折半)的相同位置 ( i' ),在前半部分的 ( j ) 中,( d(i'&j) ) 的最高位都是0,

    而因为 ( j ) 的最高位是0,( i&j ) 的最高位同样是0,所以正好符合,答案可以加上;

    对于后半部分(折半)的相同位置 ( i' ),在后半部分的 ( j ) 中,计算 ( i'&j ) 时是没有考虑最高位的,所以它们的最高位上都是0,

    但 ( i&j ) 的最高位是1,所以奇偶性都反了,答案加上的是负的;

    这样,就得到 ( a' = left ( a0' + a1' , a0' - a1' ight ) )

     

    (2) ( d(i&k) otimes d(j&k) = d( (i otimes j)&k ) )

    因为是 ( & ) ,我们就看 ( k ) 是1的那些位;

    如果 ( d(i&k) ) 是偶数,说明 ( i&k ) 有偶数个1和 ( k ) 重合,奇数同理,( j ) 同理;

    <1> 当 ( d(i&k) ) 和 ( d(j&k) ) 奇偶性相同时

    ( d(i&k) + d(j&k) ) 是偶数;

    而 ( i otimes j ) 同时消去 ( i ) 和 ( j ) 相同位置的1,不是 ( k ) 的1就算了,是 ( k ) 的1,消去的也是偶数;

    所以 ( d( (i otimes j)&k ) ) 是偶数;

    <2> 当 ( d(i&k) ) 和 ( d(j&k) ) 奇偶性不同时

    ( d(i&k) + d(j&k) ) 是奇数;

    而 ( i otimes j ) 同时消去 ( i ) 和 ( j ) 相同位置的1,不是 ( k ) 的1就算了,是 ( k ) 的1,消去的是偶数;

    所以 ( d( (i otimes j)&k ) ) 是奇数;

    这样我们就证明了 ( d(i&k) otimes d(j&k) = d( (i otimes j)&k ) )

     

    (3) 若 ( c[i] = sum_{j otimes k=i}^{} a[j]*b[k] ) ,有 ( c' = a' * b' )

    因为 ( c[i] = sum_{j otimes k=i}^{} a[j]*b[k] )

    又 ( c' = sum_{d(i,j)\%2==0}^{} c[j] - sum_{d(i,j)\%2==1}^{} c[j] )

    代入,得到 ( c'[i] = sum_{d((j otimes k)&i)\%2==0}^{} a[j]*b[k] - sum_{d((j otimes k)&i)\%2==1}^{} a[j]*b[k] )

    而 ( a'[i] * b'[i] = ( sum_{d(i,j)\%2==0}^{} a[j] - sum_{d(i,j)\%2==1}^{} a[j]) * ( sum_{d(i,j)\%2==0}^{} b[j] - sum_{d(i,j)\%2==1}^{} b[j]) )

    拆开再组合,并使用(2)得到的结论,就得到 ( a'[i] * b'[i] = ( sum_{d((j otimes k)&i)\%2==0}^{} a[j]*b[k] - sum_{d((j otimes k)&i))\%2==1}^{} a[j]*b[k]) )

    所以 ( c' = a' * b' )

    综上,我们仍然可以递归求 xor 的正变换,( a' = left ( a0' + a1' , a0' - a1' ight ) )

    2.逆变换

    根据正变换就可以知道咯:

    ( a0' = (a'0 + a'1) / 2 )

    ( a1' = (a'0 - a'1) / 2 )

    void fwt3(int *a,int tp)//a'=(a0'+a1',a0'-a1')  //a=((a0'+a1')/2,(a0'-a1')/2)
    {
      for(int mid=1;mid<lim;mid<<=1)
        for(int j=0,len=(mid<<1);j<lim;j+=len)
          for(int k=0;k<mid;k++)
        {
          int x=a[j+k],y=a[j+mid+k];
          a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y);
          if(tp==-1)a[j+k]=(ll)a[j+k]*inv%mod,a[j+mid+k]=(ll)a[j+mid+k]*inv%mod;
        }
    }
    xor

    看例题:https://www.luogu.org/problemnew/show/P4717

    代码如下:

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    typedef long long ll;
    int const xn=(1<<17),mod=998244353;
    int n,a[xn],b[xn],f[xn],g[xn],lim,inv;
    int rd()
    {
      int ret=0,f=1; char ch=getchar();
      while(ch<'0'||ch>'9'){if(ch=='0')f=0; ch=getchar();}
      while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
      return f?ret:-ret;
    }
    ll pw(ll a,int b)
    {
      ll ret=1;
      for(;b;b>>=1,a=(a*a)%mod)
        if(b&1)ret=(ret*a)%mod;
      return ret;
    }
    int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
    void fwt1(int *a,int tp)//a'=(a0',a0'+a1')  //a=(a0',a1'-a0')
    {
      for(int mid=1;mid<lim;mid<<=1)
        for(int j=0,len=(mid<<1);j<lim;j+=len)
          for(int k=0;k<mid;k++)
          a[j+mid+k]=upt(a[j+mid+k]+tp*a[j+k]);
    }
    void fwt2(int *a,int tp)//a'=(a0'+a1',a1')  //a=(a0'-a1',a1')
    {
      for(int mid=1;mid<lim;mid<<=1)
        for(int j=0,len=(mid<<1);j<lim;j+=len)
          for(int k=0;k<mid;k++)
          a[j+k]=upt(a[j+k]+tp*a[j+mid+k]);
    }
    void fwt3(int *a,int tp)//a'=(a0'+a1',a0'-a1')  //a=((a0'+a1')/2,(a0'-a1')/2)
    {
      for(int mid=1;mid<lim;mid<<=1)
        for(int j=0,len=(mid<<1);j<lim;j+=len)
          for(int k=0;k<mid;k++)
        {
          int x=a[j+k],y=a[j+mid+k];
          a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y);
          if(tp==-1)a[j+k]=(ll)a[j+k]*inv%mod,a[j+mid+k]=(ll)a[j+mid+k]*inv%mod;
        }
    }
    int main()
    {
      n=rd(); lim=(1<<n); inv=pw(2,mod-2);
      for(int i=0;i<lim;i++)a[i]=f[i]=rd();
      for(int i=0;i<lim;i++)b[i]=g[i]=rd();
      fwt1(f,1); fwt1(g,1);
      for(int i=0;i<lim;i++)f[i]=(ll)f[i]*g[i]%mod;
      fwt1(f,-1);
      for(int i=0;i<lim;i++)printf("%d ",f[i]); puts("");
    
      for(int i=0;i<lim;i++)f[i]=a[i],g[i]=b[i];
      fwt2(f,1); fwt2(g,1);
      for(int i=0;i<lim;i++)f[i]=(ll)f[i]*g[i]%mod;
      fwt2(f,-1);
      for(int i=0;i<lim;i++)printf("%d ",f[i]); puts("");
    
      for(int i=0;i<lim;i++)f[i]=a[i],g[i]=b[i];
      fwt3(f,1); fwt3(g,1);
      for(int i=0;i<lim;i++)f[i]=(ll)f[i]*g[i]%mod;
      fwt3(f,-1);
      for(int i=0;i<lim;i++)printf("%d ",f[i]); puts("");
      return 0;
    }

    参考博客:https://www.cnblogs.com/ACMLCZH/p/8022502.html

    https://blog.csdn.net/neither_nor/article/details/60335099

  • 相关阅读:
    第二十天笔记
    第十九天笔记
    第十七天笔记
    第十五天笔记
    第十六天笔记
    第十二天笔记
    数字三角形
    最大子段和与最大子矩阵和
    分组背包
    二维背包
  • 原文地址:https://www.cnblogs.com/Zinn/p/10037445.html
Copyright © 2011-2022 走看看