@description@
给定一个值域在 [0, 2^N) 的随机数生成器,给定参数 A[0...2^N-1]。
该生成器有 (frac{A_i}{sum A}) 的概率生成 i,每次生成都是独立的。
现在有一个 X,初始为 0。每次操作生成一个随机数 v 并将 X 异或 v。
对于每一个 i ∈ [0, 2^N),求期望多少次操作 X 第一次等于 i。
原题题面。
@solution@
不难想到期望 dp。定义 dp[i] 表示到达 i 的期望次数,则:
其中 (p[i] = frac{A_i}{sum A})。
朴素做法是高斯消元。显然过不了。
对于高斯消元的常规优化是利用转移的图结构(比如 DAG,链或者树),但是这个题转移的图是完全图,做不到。
怎么办?观察转移式的结构,发现它其实是异或卷积。于是我们尝试走生成函数那一套。
如果用生成函数的记法,又可以将其记作 (dpoplus P + I = dp + k imes T),其中 (I[i] = 1, T[i] = [i = 0]),(k) 是一个未知数。
注意当 n = 0 卷积是不成立的,所以需要在末尾填上一项 (k imes T)。
变一下形得到 (dpoplus (T - P) = I - k imes T),两边同时进行 fwt 得到 (dp' imes (T - P)' = I' - k imes T')。
注意到 ((T - P)') 的第 0 项始终为 0(根据 fwt 的定义可知),故 (I' - k imes T') 的第 0 项也为 0,由此可以解出 k。
但是这样一来我们又不知道 (dp'[0]) 的值为多少,再次设未知数为 q。进行逆变换时把未知数代进去一起运算就可以了。
然后 (dp) 数列就可以表示成含 q 的一次函数,而根据 (dp[0] = 0) 可以反解出 q,于是 (dp) 数列就解出来了。
@accepted code@
#include <cstdio>
const int MOD = 998244353;
const int INV2 = (MOD + 1) >> 1;
int add(int x, int y) {return (x + y >= MOD ? x + y - MOD : x + y);}
int sub(int x, int y) {return (x - y < 0 ? x - y + MOD : x - y);}
int mul(int x, int y) {return 1LL*x*y%MOD;}
int pow_mod(int b, int p) {
int ret = 1;
for(int i=p;i;i>>=1,b=mul(b,b))
if( i & 1 ) ret = mul(ret,b);
return ret;
}
struct node{
int k, b;
node() : k(0), b(0) {}
node(int _k, int _b) : k(_k), b(_b) {}
int get(int x) {return add(mul(k, x), b);}
friend node operator + (node a, node b) {
return node(add(a.k, b.k), add(a.b, b.b));
}
friend node operator - (node a, node b) {
return node(sub(a.k, b.k), sub(a.b, b.b));
}
friend node operator * (node a, int k) {
return node(mul(a.k, k), mul(a.b, k));
}
friend node operator / (node a, int k) {
return a * pow_mod(k, MOD - 2);
}
};
void fwt(node *A, int m, int type) {
int n = (1 << m), f = (type == 1 ? 1 : INV2);
for(int i=1;i<=m;i++) {
int s = (1 << i), t = (s >> 1);
for(int j=0;j<n;j+=s)
for(int k=0;k<t;k++) {
node x = A[j+k], y = A[j+k+t];
A[j+k] = (x + y)*f, A[j+k+t] = (x - y)*f;
}
}
}
node A[1<<18], B[1<<18], C[1<<18], f[1<<18];
int main() {
int N, M, S = 0; scanf("%d", &N), M = (1 << N);
for(int i=0;i<M;i++) scanf("%d", &A[i].b), S = add(S, A[i].b);
S = pow_mod(S, MOD - 2);
for(int i=0;i<M;i++) A[i].b = sub(i == 0 ? 1 : 0, mul(A[i].b, S));
for(int i=0;i<M;i++) B[i].b = 1;
C[0].b = MOD - 1;
fwt(A, N, 1), fwt(B, N, 1), fwt(C, N, 1);
int tmp = mul(B[0].b, pow_mod(C[0].b, MOD-2));
for(int i=1;i<M;i++)
f[i] = (B[i] - C[i]*tmp) / A[i].b;
f[0].k = 1; fwt(f, N, -1);
int x = sub(0, mul(pow_mod(f[0].k, MOD-2), f[0].b));
for(int i=0;i<M;i++) printf("%d
", f[i].get(x));
}
@details@
感觉我的做法很像是乱搞。。。不过我也不大清楚官方正解是啥子。。。