构造 DFT :某少女附中的体育课 题解
题意
神仙题。一题能当十题写。
一句话题意:给出一个下标系统 (A),求序列 (P) 在该下标系统下的 (N) 维卷积的 (K+1) 次幂。
先来解释一下什么意思。
一般来说,卷积是一个如下的形式
当 (circ) 为 (+) 的时候,就是我们熟悉的多项式乘法。
那如果是位运算呢?就是 ( ext{FWT}) 干的事情了。
那如果是其他的东西呢?这个题就是给了你一个 (m imes m) 的矩阵 (A) ,表示一个运算表。
然后你有一个长度为 (m^n) 的序列 (P) ,每个下标是一个 (n) 位 (m) 进制数,下标之间的运算就是按位进行矩阵 (A) 下的运算。你需要在这个意义下算出 (P) 的 (K+1) 次幂。
为了方便,以下设 (N = m^n) 。
算法一
直接暴力计算卷积即可。
预计得分: (7)。
算法二
我们注意到下标运算系统其实就是同或。直接上同或 ( ext{FWT}) 就行。
什么,你不会同或 ( ext{FWT})?没事,我也不会。你只需要会异或 ( ext{FWT}) 就行了。
我们注意到两个数的同或其实就是异或值的取反,所以做完异或 ( ext{FWT}) 之后 std::reverse
一下就行了。
但是这是做一遍卷积,两遍卷积就相当于又 std::reverse
回来了。所以 (K+1) 为奇数的时候才需要 std::reverse
。
然后题目保证了 (A) 是关于主对角线对称的(交换律),于是 (m=2) 的情况其实只有四种 (A) ,分别对应与,或,同或,异或,直接上四种 ( ext{FWT}) 就行了。
预计得分:(25) 。
前二十五分并不需要什么冷门知识。
算法三
剩下的就有点意思了。
Subtask4
中的 (A) 表示按位取 max
,类比 ( ext{FWT}) ,我们只需要将 (0) 的位置的贡献加到 (1) 上,然后再将 (1) 的位置的贡献加到 (2) 上,写起来和 $ ext{OR FWT}$ 差不多,就是个高维前缀和。
预计得分: (34) 。
算法四
为什么写起来和 ( ext{FWT}) 差不多呢?
vfk
的论文里说,( ext{FWT}) 本质上是一个高维的 ( ext{DFT}) 。而什么又是高维 ( ext{DFT}) 呢?
注意, ( ext{FFT}) 是快速求 ( ext{DFT}) 的算法,而 ( ext{DFT}) 是一个变换,不要搞混。
先来回顾一下 ( ext{DFT}) 。
我们知道 ( ext{DFT}) 是求出了一个多项式在单位根处的点值表示。为什么是主单位根呢?因为 ( ext{DFT}) 最核心的原理是求和引理
其中 (v) 是定值。这个式子可以用等比数列求和公式来证明。
(v mod n ot = 0) 时,有 (w^v ot= 1):
(v mod n = 0) 时,有 (omega^v = 1):
回顾一下我们多项式乘法卷积的式子:
我们先悄咪咪的将 ([p+q=r]) 变成 ([(p+q)mod n=r])
于是我们就可以用求和引理了!
然后我们继续化式子
于是我们只需要定义 ( ext{DFT}) 为
定义 ( ext{IDFT}) 为
于是我们就有了一个卷积性变换
(O(n^2) o O(nlog n)) 。
慢着,我们刚刚悄咪咪的将 ([p+q=r]) 变成 ([(p+q)mod n=r]) 真的没问题嘛?
当然有啊!因为原理就是这个求和引理,所以我们的 ( ext{DFT}) 是算的是 (mod n) 意义下的循环卷积。于是乎,当两个下标之和大于等于 (n) 的时候,他们的值的贡献会加到前面去。所以我们得做一个长度为 (2n) 的 ( ext{DFT}),才能得到我们想要的结果。
而什么又是高维 ( ext{DFT}) 呢?
考虑高维前缀和,我们可以对每个方向依次做一遍一维的前缀和。
高维 ( ext{DFT}) 就把求一维前缀和换成一维变换。
高维 ( ext{DFT}) 相当于对下标进行按位分治,然后分层进行 ( ext{DFT}) 。就是说,枚举当前的位,然后对下标除了这一位之外都相同的每组数做一遍 ( ext{DFT}) 就好了。
而异或 ( ext{FWT}) 就是做了一个每维长度为 (2) 的高维 ( ext{DFT}) 。
为什么呢?
因为异或运算和 (mod 2) 意义下的加法等价,每一维本质上就是一个 ( ext{DFT}) 。想想看,是不是?
虽然代码和 ( ext{FFT}) 几乎一模一样,但是原理是不同的。我们刚才说 ( ext{FFT}) 是个算法,它其实是在用分治法快速算 ( ext{DFT}) 。而 ( ext{FWT}) 实际上每维是暴力算出来的 ( ext{DFT}) 。
长度为 (2) 的 ( ext{DFT}) ,需要 (2) 次单位根。不就是 (-1) 嘛!
然后
然而我们只有两个数,所以
这样也能解释为什么回来的时候要除 (2) 。因为是在做 ( ext{IDFT}) 。
以下所有涉及到的 ( ext{DFT}) 都是高维 ( ext{DFT}) ,但是实现方法是一样的,你如果能构造出一维上的 ( ext{DFT}) ,那么你只需要按位分治,然后去掉这一位进行分组,将去掉这一位之后下标一样的数分到同一组里,然后对每组分别做一维的 ( ext{DFT}) 就好了。
回到这个题上。
Subtask5
给了一个奇怪的 (A) 。通过人类智慧稍微构造一下,我们就能找到这样一个 ( ext{DFT}) :
然后继续做就好辣!
预计得分: (43) 。
算法五
你没必要知道什么是循环群。你只需要知道它和模 (m) 意义下的加法是同构的就好了。
不过我们还是得需要知道什么是阶和生成元。
幂的定义题面里给了。
这个阶和数论里的阶是差不多的,若 (i^{j+1} = i) 即 $i^j = epsilon $ ,那么 (i) 的阶就是 (j) 。记为 (ord(i)) 。
生成元和原根差不多,就是一个数 (i) 它的 ([0,ord(i))) 次幂能遍历所有元素。其实就是那个阶最大的。
然后 (mod m) 意义下的加法的幂其实就是乘法,单位元 (epsilon) 是 (0),生成元是所有与 (m) 互质的数,例如 (1) ,它的阶为 (m) 。
然后我们暴力求出这个 ([0,m)) 里每个数在 (A) 下的阶,里面肯定有个元素的阶等于 (m) 。我们把它当成 (1) ,然后让它自己与自己运算,得出其他的数的映射关系。然后我们就把下标运算转换成了 (mod m) 的加法,直接上 ( ext{DFT}) 然后转换回来就好辣!
预计得分: (53) 。
算法六
后面的部分分我不会。就算看懂了题解我也不会实现。
LCA
的标程密密麻麻的看不懂,论文里用的伪代码,根本无从下手。
(P.S. LCA
的提交之所以在 LOJ
少女附中的榜里是最慢的是因为他交的其实是个暴力,而且还用 std::cin
std::cout
读写 (5e5) 级别的数据。)
好在这个题暴力可过。于是我们就能愉快的水过去了。
( ext{DFT}) 本质上是一个“线性变换”。
什么意思?
相当于你给那个数列乘了一个矩阵,每个数只会变成一堆其他的数再乘上某个系数后的和。
我们找到那个矩阵不就行了?
我们多项式乘法的 ( ext{DFT}) 和 ( ext{FWT}) 的 ( ext{DFT}) 的矩阵是单位根的范德蒙德矩阵。
于是我们大胆猜想不用证明我们这个矩阵里放的也是单位根的几次幂。
设这个矩阵为 (T) 。
变换就能写成
我们化一化式子,看看它需要满足什么性质。
于是 (T) 必定每一行内都满足 (T_{r(p circ q)} = T_{rp}T_{rq})。
然后根据题面里说的“循环率”,元素 (i) 的 (ord(i)+1) 次幂得是 (i) 自己。于是 (T_{rq}^{ord(q)+1} = T_{rq}) 。
所以根据这个性质我们 (m^m) 枚举,在每一个位置 (i) 填上 (ord(i)) 次单位根的 ([0,ord(i)]) 次幂就好了。然后填一个数检查一下是否当前行满足 (T_{r(p circ q)} = T_{rp}T_{rq}),填满一行之后存到矩阵里。
搜出来这个矩阵之后,给它求一个逆,就是逆变换的矩阵辣!正确性显然。
啥?你不会矩阵求逆?出门右转洛谷模板区。挺好写的。就一个高斯消元。
然而会有一个问题,如果求出来的矩阵没有逆怎么办?
但是根据 LCA
的证明,好像搜出来的矩阵各行必定都线性无关。就是一定有逆。
搜出来矩阵和逆矩阵直接 ( ext{(I)DFT}) 就行了。随便搜搜就能过。跑的挺快的。
代码超级好写。
顺便提一下,正解使用的是 Subtask7
和 Subtask8
的方法科学构造出那个变换矩阵。想学的可以学一学。代码不长。
预计得分:(100)。
放一个部分分比较全的代码吧,可以参考一下。
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#include <vector>
inline int read() {
register int ret, cc;
while (!isdigit(cc = getchar())){}ret = cc-48;
while ( isdigit(cc = getchar())) ret = cc-48+ret*10;
return ret;
}
const int MOD = 232792561;
const int G = 71;
inline int add(int a, int b) { return (a += b) >= MOD ? a -= MOD : a; }
inline int mul(int a, int b) { return 1ll * a * b % MOD; }
inline int qpow(int a, int p) {
int ret = 1;
for ( ; p; a = mul(a, a), p >>= 1)
if (p & 1) ret = mul(a, ret);
return ret;
}
int N, M, X, U;
long long K;
namespace SubTask1 {
const int MAXN = 30;
const int MAXU = 1200;
int A[MAXN][MAXN];
int P[MAXU];
std::vector<int> state[MAXU];
int trans[MAXU][MAXU];
int f[MAXN][MAXU];
inline std::vector<int> decode(int S) {
std::vector<int> ret(N);
for (int i = 0; i < N; ++i) ret[i] = S % M, S /= M;
return ret;
}
inline int encode(const std::vector<int>& S) {
int ret = 0;
for (int i = N - 1; i >= 0; --i) ret = ret * M + S[i];
return ret;
}
inline std::vector<int> transform(const std::vector<int>& lhs, const std::vector<int>& rhs) {
std::vector<int> ret(N);
for (int i = 0; i < N; ++i) ret[i] = A[lhs[i]][rhs[i]];
return ret;
}
void Main() {
for (int i = 0; i < M; ++i) for (int j = 0; j < M; ++j) A[i][j] = read();
for (int i = 0; i < U; ++i) P[i] = read();
for (int i = 0; i < U; ++i) state[i] = decode(i);
for (int i = 0; i < U; ++i) for (int j = 0; j < U; ++j) trans[i][j] = encode(transform(state[i], state[j]));
for (int i = 0; i < U; ++i) f[0][i] = P[i];
for (int i = 0; i < K; ++i)
for (int j = 0; j < U; ++j)
for (int k = 0; k < U; ++k)
f[i+1][trans[j][k]] = add(f[i+1][trans[j][k]], mul(f[i][j], P[k]));
for (int i = 0; i < U; ++i) printf("%d
", f[K][i]);
}
}
namespace SubTask23 {
const int MAXU = 1 << 20 | 1;
int P[MAXU];
inline void FWT_OR(int* a, int n, int opt) {
for (int i = 1; i < n; i <<= 1)
for (int j = 0, p = i << 1; j < n; j += p)
for (int k = 0; k < i; ++k)
(a[j + k + i] += (opt * a[j + k] + MOD) % MOD) %= MOD;
}
inline void FWT_AND(int* a, int n, int opt) {
for (int i = 1; i < n; i <<= 1)
for (int j = 0, p = i << 1; j < n; j += p)
for (int k = 0; k < i; ++k)
(a[j + k] += (opt * a[j + k + i] + MOD) % MOD) %= MOD;
}
inline void FWT_XOR(int* a, int n, int opt) {
for (int i = 1; i < n; i <<= 1)
for (int j = 0, p = i << 1; j < n; j += p)
for (int k = 0; k < i; ++k) {
int x = a[j + k], y = a[j + k + i];
a[j + k] = (x + y) % MOD;
a[j + k + i] = (x - y + MOD) % MOD;
}
if (opt == -1) {
int inv = qpow(n, MOD - 2);
for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % MOD;
}
}
inline void FWT_NXOR(int* a, int n, int opt) {
FWT_XOR(a, n, opt);
if (opt == -1) std::reverse(a, a + n);
}
void Main() {
int type = 0;
for (int i = 0; i < 4; ++i) type = type << 1 | read();
for (int i = 0; i < U; ++i) P[i] = read();
if ((~K&1) && type == 9) type ^= 15;
switch (type) {
case 9: FWT_NXOR(P, U, 1); break;
case 6: FWT_XOR(P, U, 1); break;
case 7: FWT_OR(P, U, 1); break;
case 1: FWT_AND(P, U, 1); break;
}
for (int i = 0; i < U; ++i) P[i] = qpow(P[i], (K + 1) % (MOD - 1));
switch (type) {
case 9: FWT_NXOR(P, U, -1); break;
case 6: FWT_XOR(P, U, -1); break;
case 7: FWT_OR(P, U, -1); break;
case 1: FWT_AND(P, U, -1); break;
}
for (int i = 0; i < U; ++i) printf("%d
", P[i]);
}
}
int P[500010];
int A[30][30];
int W[30][30];
std::vector<int> group;
const std::vector<int> subtask4({0, 1, 2, 1, 1, 2, 2, 2, 2});
const std::vector<int> subtask5({0, 1, 0, 1, 0, 1, 0, 1, 2});
namespace SubTask4 {
void Main() {
for (int i = 1; i <= U; i *= M) {
for (int j = 0; j < U; ++j) {
int p = j % i, q = j / i;
if (q % 3) P[j] = add(P[j], P[(q-1)*i+p]);
}
}
for (int i = 0; i < U; ++i) P[i] = qpow(P[i], (K+1)%(MOD-1));
for (int i = 1; i <= U; i *= M) {
for (int j = U-1; j >= 0; --j) {
int p = j % i, q = j / i;
if (q % 3) P[j] = add(P[j], MOD-P[(q-1)*i+p]);
}
}
for (int i = 0; i < U; ++i) printf("%d
", P[i]);
}
}
namespace SubTask5 {
void Main() {
const int inv2 = (MOD+1)/2;
for (int i = 1; i < U; i *= 3) {
for (int j = 0, p = i*3; j < U; j += p) {
for (int k = 0; k < i; ++k) {
int x = P[j+k], y = P[j+k+i], z = P[j+k+2*i];
P[j+k] = z;
P[j+k+i] = add(x, add(y, z));
P[j+k+2*i] = add(x, add(MOD-y, z));
}
}
}
for (int i = 0; i < U; ++i) P[i] = qpow(P[i], (K+1)%(MOD-1));
for (int i = 1; i < U; i *= 3) {
for (int j = 0, p = i*3; j < U; j += p) {
for (int k = 0; k < i; ++k) {
int x = P[j+k], y = P[j+k+i], z = P[j+k+2*i];
P[j+k] = add(mul(inv2, add(y, z)), MOD-x);
P[j+k+i] = mul(inv2, add(y, MOD-z));
P[j+k+2*i] = x;
}
}
}
for (int i = 0; i < U; ++i) printf("%d
", P[i]);
}
}
namespace SubTask6 {
inline int calc(int a, int b) { return A[a][b]; }
int tr[10];
inline int trans(int x) {
int ret = 0;
for (int i = 1; i < U; i *= M)
ret += tr[x / i % M] * i;
return ret;
}
inline int getroot() {
int root = -1;
for (int i = 0; i < M; ++i) {
int cur = A[i][i];
int rank = 1;
while (cur != i) {
++rank, cur = A[cur][i];
}
if (rank == M) {
root = i;
break;
}
}
return root;
}
int PP[500010];
int tmp[10];
void Main() {
int root = getroot();
for (int i = 1, cur = root; i <= M; ++i, cur = A[cur][root]) tr[cur] = i % M;
for (int i = 0; i < U; ++i) PP[trans(i)] = P[i];
for (int i = 1; i < U; i *= M) {
for (int j = 0, p = i * M; j < U; j += p) {
for (int k = 0; k < i; ++k) {
for (int r = 0; r < M; ++r) tmp[r] = PP[j+k+r*i];
int w = 1, wm = W[M][1];
for (int r = 0; r < M; ++r, w = mul(w, wm)) {
int sum = 0;
for (int t = M-1; ~t; --t) sum = add(mul(sum, w), tmp[t]);
PP[j+k+r*i] = sum;
}
}
}
}
for (int i = 0; i < U; ++i) PP[i] = qpow(PP[i], (K+1) % (MOD-1));
int invm = qpow(M, MOD-2);
for (int i = 1; i < U; i *= M) {
for (int j = 0, p = i * M; j < U; j += p) {
for (int k = 0; k < i; ++k) {
for (int r = 0; r < M; ++r) tmp[r] = PP[j+k+r*i];
int w = 1, wm = qpow(W[M][1], MOD-2);
for (int r = 0; r < M; ++r, w = mul(w, wm)) {
int sum = 0;
for (int t = M-1; ~t; --t) sum = add(mul(sum, w), tmp[t]);
PP[j+k+r*i] = mul(sum, invm);
}
}
}
}
for (int i = 0; i < U; ++i) printf("%d
", PP[trans(i)]);
}
}
namespace Std_Dfs {
bool vis[30];
int ord[30];
inline void getord() {
for (int i = 0; i < M; ++i) {
int cur = i;
do
ord[i]++, cur = A[cur][i];
while (cur != i);
}
}
int T[30][30];
int R[30][30];
int cnt;
int tmp[500010];
inline bool Judge(int n) {
for (int i = 0; i <= n; ++i)
for (int j = 0; j <= n; ++j)
if (A[i][j] <= n && mul(tmp[i], tmp[j]) != tmp[A[i][j]])
return false;
return true;
}
void Dfs(int n) {
if (cnt == M) return;
if (n == M) {
bool zero = 1;
for (int i = 0; i < M; ++i) if (tmp[i]) {
zero = 0;
break;
}
if (zero) return;
for (int i = 0; i < M; ++i) T[cnt][i] = tmp[i];
cnt++;
return;
}
for (int i = 0; i <= ord[n]; ++i){
tmp[n] = W[ord[n]][i];
if (Judge(n)) Dfs(n+1);
}
}
void Matrix_Inv() {
static int C[30][30];
memcpy(C, T, sizeof C);
for (int i = 0; i < M; ++i) R[i][i] = 1;
for (int i = 0; i < M; ++i) {
int p = i;
for (int j = i; j < M; ++j) if (C[j][i]) { p = j; break; }
if (p != i) for (int j = 0; j < M; ++j)
std::swap(C[i][j], C[p][j]), std::swap(R[i][j], R[p][j]);
for (int j = 0; j < M; ++j) if (i != j) {
int rate = mul(C[j][i], qpow(C[i][i], MOD-2));
for (int k = 0; k < M; ++k) {
C[j][k] = add(C[j][k], MOD-mul(C[i][k], rate));
R[j][k] = add(R[j][k], MOD-mul(R[i][k], rate));
}
}
}
for (int i = 0; i < M; ++i) {
int inv = qpow(C[i][i], MOD-2);
for (int j = 0; j < M; ++j) R[i][j] = mul(R[i][j], inv);
}
}
void dft(int n, int *a, int C[30][30]) {
if (n == 1) return;
int b = n / M;
for (int i = 0; i < M; ++i) dft(b, a + i * b, C);
for (int i = 0; i < n; ++i) tmp[i] = 0;
for (int i = 0; i < M; ++i)
for (int j = 0; j < M; ++j)
for (int k = 0; k < b; ++k)
tmp[i * b + k] = add(tmp[i * b + k], mul(a[j * b + k], C[i][j]));
for (int i = 0; i < n; ++i) a[i] = tmp[i];
}
void Main() {
getord();
Dfs(0);
Matrix_Inv();
dft(U, P, T);
for (int i = 0; i < U; ++i) P[i] = qpow(P[i], (K+1)%(MOD-1));
dft(U, P, R);
for (int i = 0; i < U; ++i) printf("%d
", P[i]);
}
}
int main() {
#ifdef ARK
freopen("F.6.0.in", "r", stdin);
freopen("test.out", "w", stdout);
#endif
for (int i = 1; i <= 22; ++i) {
W[i][0] = 1;
int rt = qpow(G, (MOD-1) / i);
for (int j = 1; j < i; ++j) W[i][j] = mul(W[i][j-1], rt);
}
N = read(), M = read(), U = qpow(M, N), scanf("%lld", &K);
for (int i = 0; i < M; ++i)
for (int j = 0; j < M; ++j) A[i][j] = read();
for (int i = 0; i < U; ++i) P[i] = read();
Std_Dfs::Main();
}
总结
你过了这个题之后,抽象代数那点基础知识应该就能明白一点点了,就能看 LCA
的论文了。
然后你就能彻底理解 ( ext{DFT}) 的本质和高位前缀和的思想,不只是局限于背板子了。
同时你能达到理性颓废的巅峰,体会数学的美妙和偷税的乐趣。