用来解决在下标中进行位运算的卷积
具体形式就是求
思路大概就是把序列 (a) 变换为 (fwt(a)),(b,c) 同理,使得 (fwt(c)=fwt(a) fwt(b)),这样得到了 (fwt(c)) 再变换回来
或运算
构造 (fwt(a)_i=sum_{i|j=i} a_j)
也就是,取所有下标是 (i) 的二进制位中子集的 (a_j) 的和
由于有 (i|j=i,i|k=iRightarrow i|(j|k)=i),所以:
于是考虑如何求 (fwt(a)),考虑分治,用 (a_0,a_1) 分别表示 (a) 序列中下标第一个二进制位为 (0/1) 的情况(各 (2^{n-1}) 个数),则有:
(operatorname{merge}) 为链接,加号为每一位分别相加
就是说由于是求下标是它子集的元素的和,那么 (a_1) 是可以将第一个二进制位改为 (0),得到它的一个子集,也就是包含了 (a_0)
再考虑如何由 (fwt(a)) 求出 (a),其实直接反过来就好:
与运算
其实和或运算类似,构造 (fwt(a)_i=sum_{i&j=i} a_j)
然后分治的时候,可以把 (a_0) 的第一个二进制位改为 (1),包含上 (a_1),于是:
异或
稍微复杂一些,不能用子集的关系表示了
设 (f(i,j)=operatorname{popcount}(i&j) mod 2)
有:(f(i,j)operatorname{xor}f(i,k)=f(i,joperatorname{xor}k))
证明大概就是,因为是先与运算再统计二进制中 (1) 的个数,所以只用考虑 (i) 为 (1) 的那几位,如果 (j,k) 在这些位上也是 (1) 的个数的奇偶性相同,那么他们中有一部分是重叠的会被异或掉,剩下的显然奇偶性仍然相同,那么总共偶数个,结果为 (0)
如果奇偶性不同,那么重叠的部分被异或掉以后,剩下的奇偶性仍然不同,总共奇数个,结果为 (1)
那么此时就可以构造:
那么相乘就是
然后考虑分治的时候如何计算,有:
原理上,就是对于前 (2^{n-1}) 个数,最高位是 (0),由于 (0&0=0&1=0),对 (f) 的结果没有影响,直接简单相加
后 (2^{n-1}) 个数,最高位是 (1),由于 (1&0=0,1&1=1),使得 (f) 结果改变,应为 (-a_1)
模板题:https://www.luogu.com.cn/problem/P4717
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#define reg register
#define LL_INF (long long)(0x3f3f3f3f3f3f3f3f)
#define INT_INF (int)(0x3f3f3f3f)
inline int read(){
register int x=0;register int y=1;
register char c=std::getchar();
while(c<'0'||c>'9'){if(c=='-') y=0;c=getchar();}
while(c>='0'&&c<='9'){x=x*10+(c^48);c=getchar();}
return y?x:-x;
}
#define mod 998244353
#define inv 499122177
#define N 131078
int n;
long long A[N],B[N],C[N];
long long a[N],b[N],c[N];
inline void OR(long long *f,long long *a,int x){
for(reg int i=0;i<n;i++) f[i]=a[i];
for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++)
f[i+j+k]=(f[i+j+k]+mod+(x?f[i+j]:-f[i+j]))%mod;
}
inline void AND(long long *f,long long *a,int x){
for(reg int i=0;i<n;i++) f[i]=a[i];
for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++)
f[i+j]=(f[i+j]+mod+(x?f[i+j+k]:-f[i+j+k]))%mod;
}
inline void XOR(long long *f,long long *a,int x){
for(reg int i=0;i<n;i++) f[i]=a[i];
for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++){
f[i+j]=(f[i+j]+f[i+j+k])%mod;
f[i+j+k]=(f[i+j]-f[i+j+k]-f[i+j+k]+mod+mod)%mod;
if(!x) f[i+j]=f[i+j]*inv%mod,f[i+j+k]=f[i+j+k]*inv%mod;
}
}
inline void calc(long long *a,long long *b,long long *c){
for(reg int i=0;i<n;i++) c[i]=a[i]*b[i]%mod;
}
int main(){
n=(1<<read());
for(reg int i=0;i<n;i++) A[i]=read();
for(reg int i=0;i<n;i++) B[i]=read();
OR(a,A,1);OR(b,B,1);calc(a,b,c);OR(C,c,0);
for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
AND(a,A,1);AND(b,B,1);calc(a,b,c);AND(C,c,0);
for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
XOR(a,A,1);XOR(b,B,1);calc(a,b,c);XOR(C,c,0);
for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
return 0;
}