快速沃尔什变换FWT
概述
FFT 是加速这样的卷积 :
流程是这样的:
- (A, B Rightarrow FFT(A), FFT(B)) (带入单位根)
- (FFT(A) cdot FFT(B)) (点乘)
- (IFFT(FFT(A) cdot FFT(B)) Rightarrow A * B)
而对于位运算的卷积, (oplus) 暂指异或
通过构造相应的 (FWT(A)), 尝试以同样的步骤快速得到卷积
构造的流程大概是这样的:
- 构造 FWT 的通式
- 得到 FWT 分治递归式
- 通过上式得到 IFWT 的分治递归式
时刻注意 FWT 是系数转点值的作用, IFWT 相反,
然后诸多性质可以感性得到, 因为 FWT 变换没有 FFT 那样复杂, 就不需要大量公式的推导
or 卷积
构造这样的 FWT :
来验证一下:
- ((A | B)_i) 由下标为 (i) 的子集的 ((A|B)_j) 相加, 而他们由两个或起来是 (j) 的两个下标相乘, 下标也是 (i) 的子集.
- 下标是 (i) 的子集的 (A_j cdot B_k), 或起来只能组成 (i) 的子集.
- 不难看出 1 和 2 中所指的两个下标互为充要条件
再形象一点?
正向逆向都走得通
(先跳到 and 卷积的 FWT 构造)
接着看如何分治求 FWT 和 IFWT:
将 (A) 按最高二进制位是 0/1 分成 (A_0, A_1), 其实就是前一半后一半
根据通式, (FWT(A_0)) 只能来自 (A_0), 也就是说 (FWT(A)) 的前一半就是 (FWT(A_0)),
对于 (FWT(A_1)), 一部分来自最高位不是 0 的下标, 另一部分反之, 前者即为 (A_0), 后者为 (A_1),
考虑 (A_0) 和 (A_1) 去掉最高位后依次对应, 易得 (FWT(A)_{n / 2 + i} = FWT(A_0)_i + FWT(A_1)_i)
即, (FWT(A)) 前一半为 (FWT(A_0)), 后一半为 (FWT(A_0) + FWT(A_1))
然后考虑 IFWT:
同样 (A_0, A_1), 现在将一组点值转回去
显然前一半的点值直接由 (IFWT(A_0)) 得到,
后一半系数呢? 既然合法的点值与系数组组对应, 那么可以有它生成的点值得到, 即 (IFWT(A_1 - A_0)),
然而这个式子并不和谐, 考虑 "一组系数带入一个值 - 另一组系数带入这个值 = (两组系数相减)带入这个值", (IFWT(A_1 - A_0) = IFWT(A_1) - IFWT(A_0))
即, (IFWT(A)) 前一半为 (IFWT(A_0)), 后一半为 (IFWT(A_1) - IFWT(A_0))
and 卷积
模仿 or 来构造 FWT
同样是符合这个的
证明方法类似
接下来还是很像(想一想)
xor 卷积
很显然这个 FWT 的通式不是之前的套路
先咕着... https://blog.csdn.net/xyyxyyx/article/details/103564869
其中 ({|x|}) 表示二进制 1 位的个数
考虑递归式 FWT:
先考虑 (A_0), (FWT(A_0)) 不用说了, 但是 (j) 也可以取右半边的, 那么对应位的 (A_1) 比自己多一位了, 分类讨论可得 (FWT(A_1)) 贡献正负性不变
在考虑 (A_1), 同理, 但是 (A_0) 比自己少一位, 分类讨论可得 (FWT(A_0)) 的正负性不变, 但是自己的 (FWT(A_1)) 贡献正负性变了
即, (FWT(A) = merge(FWT(A_0) + FWT(A_1), FWT(A_0)) - FWT(A_1)))
然后很容易得到
代码
luogu 的板子题
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
inline LL in()
{
LL x = 0, flag = 1; char ch = getchar();
while (!isdigit(ch)) { if (ch == '-') flag = -1; ch = getchar(); }
while (isdigit(ch)) x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
return x * flag;
}
typedef double LDB;
const int MAXN = (1 << 20) + 10;
int lowbit(int x) { return x & (-x); }
int n, len;
const LL MOD = 998244353, inv2 = 499122177;
void FWT_or(LL * a, int len, int sgn)
{
// FWT(A) = merge(FWT(A0) + FWT(A1), FWT(A1)) ;
for (int mid = 1; mid < len; mid <<= 1)
for (int i = 0; i < len; i += (mid << 1))
for (int j = 0; j < mid; ++ j)
{
if (sgn == -1) (a[i + j + mid] += MOD - a[i + j]) %= MOD;
else (a[i + j + mid] += a[i + j]) %= MOD;
}
}
void FWT_and(LL * a, int len, int sgn)
{
// FWT(A) = merge(FWT(A0, FWT(A0) + FWT(A1)) ;
for (int mid = 1; mid < len; mid <<= 1)
for (int i = 0; i < len; i += (mid << 1))
for (int j = 0; j < mid; ++ j)
{
if (sgn == -1) (a[i + j] += MOD - a[i + j + mid]) %= MOD;
else (a[i + j] += a[i + j + mid]) %= MOD;
}
}
void FWT_xor(LL * a, int len, int sgn)
{
// FWT(A) = merge(FWT(A0) + FWT(A1), FWT(A0) - FWT(A1)) ;
// IFWT(A) = merge((IFWT(A0) + IFWT(A1)) / 2, (IFWT(A0) - IFWT(A1)) / 2) ;
for (int mid = 1; mid < len; mid <<= 1)
for (int i = 0; i < len; i += (mid << 1))
for (int j = 0; j < mid; ++ j)
{
LL x = a[i + j], y = a[i + j + mid];
a[i + j] = (x + y) % MOD, a[i + j + mid] = (x + MOD - y) % MOD;
if (sgn == -1)
(a[i + j] *= inv2) %= MOD, (a[i + j + mid] *= inv2) %= MOD;
}
}
LL reca[MAXN], recb[MAXN];
LL a[MAXN], b[MAXN];
void solve_or()
{
for (int i = 0; i < len; ++ i) a[i] = reca[i], b[i] = recb[i];
FWT_or(a, len, 1); FWT_or(b, len, 1);
for (int i = 0; i < len; ++ i) (a[i] *= b[i]) %= MOD;
FWT_or(a, len, -1);
for (int i = 0; i < len; ++ i) printf("%lld ", a[i]); puts("");
}
void solve_and()
{
for (int i = 0; i < len; ++ i) a[i] = reca[i], b[i] = recb[i];
FWT_and(a, len, 1); FWT_and(b, len, 1);
for (int i = 0; i < len; ++ i) (a[i] *= b[i]) %= MOD;
FWT_and(a, len, -1);
for (int i = 0; i < len; ++ i) printf("%lld ", a[i]); puts("");
}
void solve_xor()
{
for (int i = 0; i < len; ++ i) a[i] = reca[i], b[i] = recb[i];
FWT_xor(a, len, 1); FWT_xor(b, len, 1);
for (int i = 0; i < len; ++ i) (a[i] *= b[i]) %= MOD;
FWT_xor(a, len, -1);
for (int i = 0; i < len; ++ i) printf("%lld ", a[i]); puts("");
}
int main()
{
n = in();
len = 1 << n;
for (int i = 0; i < len; ++ i) reca[i] = in();
for (int i = 0; i < len; ++ i) recb[i] = in();
solve_or();
solve_and();
solve_xor();
return 0;
}