概述
FWT的大体思路就是把要求的 C(x)=A(x)×B(x) 即 ( c[i]=sumlimits_{j?k=i} (a[j]*b[k]) ) 变换成这样的:( c^{'}[i]=a^{'}[i]*b^{'}[i] )。
只要知道 c'[ i ] 和 c[ i ] 的关系,就能把 A(x)、B(x) 变成 A'(x)、B'(x) ,从而算出 C'(x) ,再把 C'(x) 变成 C(x)。
或卷积
定义( c^{'}[i]=sumlimits_{j | i=i} c[j] ),则( c^{'}[i]=a^{'}[i]*b^{'}[i] )
证明:( c^{'}[i]=sumlimits_{j | i=i} c[j] )
因为( c[i]=sumlimits_{j|k=i}a[j]*b[k] )
所以 ( c^{'}[i]=sumlimits_{(j|k)|i=i}a[j]*b[k] )
( =sumlimits_{(j|k)|i=i}a[j]*b[k]=sumlimits_{j|i=i}a[j] * sumlimits_{k|i=i}b[k] )
又有:( a^{'}[i]=sumlimits_{j|i=i}a[j] ) ( b^{'}[i]=sumlimits_{j|i=i}b[j] )
所以( c^{'}[i]=a^{'}[i]*b^{'}[i] )
接下来考虑怎么把 A(x) 变成 A'(x) 。
考虑按位来做,比如从低位到高位枚举,则每一部分左边一半的该位全是0、右边一半的该位全是1;记左边为A0,右边为A1。
如:00001111
00110011
01010101
如果已经算好了A0和A1,考虑用它们求出A。比如算好了第 1~2 位置的值A0和第 3~4 位置的值A1,想求第 1~4 位置的值A;(那么现在是枚举到了二进制第二位了)
此时的A0里没有A1位置对它的贡献,A1里也没有A0位置对它的贡献;考虑两部分位置的值怎样互相贡献;
考虑左边和右边的对应位置,它们只有最高位一个是0一个是1的不同;则是左边对应位置的子集的位置一定也是右边对应位置的子集,可以这样做:A0' = A0,A1' = A0+A1
所以模仿FFT的框架写一个就行了。(感觉这里求 ( sumlimits_{j|i=i}a[j] ) 的思路和高维前缀和很像)
但不用弄那个 r[ ] 来换位置(因为不是弄偶数项和奇数项,而是真的前半部分和后半部分);不过就算换了位置也可以!
因为那样换一下位置相当于是每个位置的角标被翻转了,比如上面那8个位置的角标会变成:
01010101
00110011
00001111
这样的话,自己“从低位到高位枚举”可以看作从高位到低位枚举,一切就没问题了。主要是因为位运算每一位是独立的嘛。
它的逆变换是这样想:因为 A0' = A0,A1' = A0+A1;所以 A0 = A0',A1 = A1'-A0 = A1'-A0'。刚才是从低位到高位枚举的话,现在要从高位到低位枚举。
但其实还是从低位到高位枚举也是对的!
考虑一个位置k,它加上的那些 “对应位置” j 的特点是 j 只和 k 有一位不同。比如从低到高枚举到第3位的时候 k 位置的值加上了 j 位置的值,说明二进制第3位上 j 是0、k是1,第3位之前 j 和 k 一样(因为“对应”嘛),而第3位之后 j 和 k 其实也一样(因为第3位之后 j 和 k 就变成“一块"里的了,再高的位会一起变成0或1之类的);
从 A'(x) 变回 A(x) 的过程中,比如第一步的时候,每个 a[ i ] 都记录着所有 角标是 i 子集的a的权值和 ;
从低到高枚举到第一个 k 是1的位置(除了最低位),比如是第3位,则此时 a'[ k ] - a'[ j ] 减去的值是 “角标第3位是0、其余部分是 k 的子集” 的那些位置的值;剩下的就是 “角标第3位是1、其余部分是 k 的子集” 的值。
接下来枚举到下一个 k 是1的位置,比如是第5位;因为 j 的其它位上的值都和 k 一样,所以此时 j 也是经历过第3位时的一番操作;则此时 a'[ k ] - a'[ j ] 减去的值是“角标第3位是1、第5位是0、其余部分是 k 的子集”的那些位置的值;则 a'[ k ] 剩下的值是 “角标第3位是1、第5位是1、其余部分是 k 的子集” 的位置的值;
这样一直枚举到最后,剩下的就是 “角标在 k 是1的位上是1、其余位上是 k 的自己” 位置的值,即只剩正好的 a[ k ] 了,于是此时 a'[ k ] = a[ k ] 。
与卷积
和或卷积一样。变换:A0'=A0+A1,A1'=A1 逆变换:A0 = A0'-A1 = A0'-A1',A1=A1'
异或卷积
定义 ( c^{'}[i]=sumlimits_{j & i有偶数个1} c[j] - sumlimits_{j & i有奇数个1} c[j] )
考虑证明 ( c^{'}[i]=a^{'}[i]*b^{'}[i] )
证明:因为 ( c[i]=sumlimits_{j otimes k=i} a[j]*b[k] )
所以 ( c^{'}[i]=sumlimits_{ (j otimes k)与 i 有偶数个1重合 } a[j]*b[k] - sumlimits_{ (j otimes k)与 i 有奇数个1重合 } a[j]*b[k] )
又 ( a^{'}[i]*b^{'}[i] = ( sumlimits_{j & i有偶数个1}a[j] - sumlimits_{j & i有奇数个1}a[j] ) * ( sumlimits_{j & i有偶数个1}b[j] - sumlimits_{j & i有奇数个1}b[j] ) )
( = sumlimits_{j & i有偶数个1}a[j]*b[j] + sumlimits_{j & i有奇数个1}a[j]*b[j] - sumlimits_{j & i有偶数个1,k & i有奇数个1}a[j]*b[k] - sumlimits_{j & i有奇数个1,k & i有偶数个1}a[j]*b[k] )
( = sumlimits_{j & i与k & i的1的个数奇偶性相同}a[j]*b[k] - sumlimits_{j & i与k & i的1的个数奇偶性不同}a[j]*b[k] )
( = sumlimits_{(j otimes k)与 i 有偶数个1重合}a[j]*b[k] - sumlimits_{(j otimes k)与 i 有奇数个1重合}a[j]*b[k] )
(这一步等价是因为异或的时候,如果 j 和 k 有公共位置的1,那么一次会消掉2个1;所以 ( (j&i)的1的个数 + (k&i)的1的个数 ) 在 j 和 k 异或之后奇偶性不会变)
所以 ( c^{'}[i]=a^{'}[i]*b^{'}[i] )
接下来考虑怎么把A(x)变成A'(x)。
还是有前一半的A0和后一半的A1。对应位置 & 起来之后,那个最高位还是0;
所以对于A0里的一个a[ i ]来说,记和它 & 起来的那些位置 j (其实 j 遍历了所有A0里的位置)在A1里的对应位置为 j' ,则 j & i == j' & i;所以A0'=A0+A1;
而对于A1里的一个a[ i ]来说,算A1的时候A1的标号的最高位还没被考虑(即视作0),合并的时候A1的最高位变成1了;设 i 在A1的 & 起来的那些位置 j 在A0里的对应位置为 j',则 j & i 比 j' & i 多了一个1(最高位即当前枚举到的位),所以当 i 和A1里的 j 匹配时,单独算A1时算好的 a[ i ] = sigma - sigma 里的两个 sigma 的位置换了一下,也就是符号变了;所以A1'=A0-A1。
它的逆变换就是:A0=(A0'+A1')/2,A1=(A0'-A1')/2。
关于实现方法的讨论就和或卷积一样。
模板
洛谷4717 【模板】快速沃尔什变换
题目:https://www.luogu.org/problemnew/show/P4717
不知为何跑得很慢。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=(1<<18)+5,mod=998244353; int n,a[3][N],b[3][N],c[3][N],len,r[N],inv; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9') ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar(); return fx?ret:-ret; } int g[20]; void wrt(int x) { if(x<0)putchar('-'),x=-x; if(!x){printf("0 ");return;} int t=0;while(x)g[++t]=x%10,x/=10; while(t)putchar(g[t]+'0'),t--;putchar(' '); } void upd(int &x){x>=mod?x-=mod:0;} void fwt0(int *a,bool fx) for(int R=2;R<=len;R<<=1) for(int i=0,m=R>>1;i<len;i+=R) for(int j=0;j<m;j++) (fx?a[i+m+j]+=mod-a[i+j]:a[i+m+j]+=a[i+j]),upd(a[i+m+j]); } void fwt1(int *a,bool fx) for(int R=2;R<=len;R<<=1) for(int i=0,m=R>>1;i<len;i+=R) for(int j=0;j<m;j++) (fx?a[i+j]+=mod-a[i+m+j]:a[i+j]+=a[i+m+j]),upd(a[i+j]); } void fwt2(int *a,bool fx) for(int R=2;R<=len;R<<=1) { for(int i=0,m=R>>1;i<len;i+=R) for(int j=0;j<m;j++) { int x=a[i+j]+a[i+m+j],y=a[i+j]+mod-a[i+m+j]; upd(x); upd(y); fx?(x=(ll)x*inv%mod,y=(ll)y*inv%mod):0; a[i+j]=x; a[i+m+j]=y; } } } int main() { n=(1<<rdn()); for(int i=0;i<n;i++)a[0][i]=a[1][i]=a[2][i]=rdn(); for(int i=0;i<n;i++)b[0][i]=b[1][i]=b[2][i]=rdn(); len=n<<1; for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0); int k=mod-2,tmp=2;inv=1; while(k){if(k&1)inv=(ll)inv*tmp%mod;tmp=(ll)tmp*tmp%mod;k>>=1;} fwt0(a[0],0); fwt0(b[0],0); fwt1(a[1],0); fwt1(b[1],0); fwt2(a[2],0); fwt2(b[2],0); for(int t=0;t<3;t++) for(int i=0;i<len;i++) c[t][i]=(ll)a[t][i]*b[t][i]%mod; fwt0(c[0],1); fwt1(c[1],1); fwt2(c[2],1); for(int t=0;t<3;t++,puts("")) for(int i=0;i<n;i++)wrt(c[t][i]); return 0; }