Sol
题目需要求访问完所有点(假设有 (n) 个)的期望步数,也就是要求
[E(max{x_1,x_2,cdots,x_n})
]
这里 (x_1)、(x_2)、(cdots)、(x_n) 表示第一次到达对应编号的关键点的时间。
根据 Min-Max
容斥,有
[E(max{S})=E(max{x_1,x_2,cdots,x_n})=Eleft(sum_{S'subseteq S}(-1)^{|S'|-1}min{S'}
ight)=sum_{S'subseteq S}(-1)^{|S'|+1}E(min{S'})(S'
ephi)
]
最后一步等号根据期望的线性性。
对于某一个集合 (S),设 (f_u) 表示从节点 (u) 开始,第一次到达 任意 给定点期望步数。根据期望 DP
:当 (uin S) 时
[f_u=0
]
当 (u otin S) 时
[f_u=frac1{deg_u}left(f_{fa_u}+sum_{vin son_u}f_v
ight)+1
]
直接高斯消元复杂度很劣。注意到这是一棵树。
根据线性代数的知识,使用待定系数法。令 (f_u=k_uf_{fa_u}+b_u),转化下原式
[egin{aligned}
&f_u=frac1{deg_u}left(f_{fa_u}+sum_{vin son_u}k_vf_u+b_v
ight)+1\
&Rightarrow deg_uf_u=f_{fa_u}+f_usum_{vin son_u}k_v+sum_{vin son_u}b_v+deg_u\
&Rightarrow f_uleft(deg_u-sum_{vin son_u}k_v
ight)=f_{fa_u}+sum_{vin son_u}b_v+deg_u\
&Rightarrow f_u=frac1{deg_u-sum_{vin son_u}k_v}f_{fa_u}+frac{deg_u+sum_{vin son_u}b_v}{deg_u-sum_{vin son_u}k_v}
end{aligned}
]
因此有
[egin{cases}
egin{aligned}
k_u&=frac1{deg_u-sum_{vin son_u}k_v}\
b_u&=frac{deg_u+sum_{vin son_u}b_v}{deg_u-sum_{vin son_u}k_v}
end{aligned}
end{cases}
]
当 (u in S) 时,(k_u=b_u=0)。
可以使用 DP
求系数。由于本题的特殊性((n) 很小,或者题目有保证),不存在分母为模数的倍数的情况。故求出系数,复杂度为 (mathcal O(nlog P)),注意要求逆元。总共要求 (2^n) 种情况,故复杂度为 (mathcal O(n2^nlog P))。
预处理所有询问,发现实际上是一个子集枚举,使用 FWT
优化即可。注意系数。这部分复杂度 (mathcal O(n2^n))。总复杂度为 (mathcal O(n2^nlog P))。
Code
#include <bits/stdc++.h>
const int N = 19, P = 998244353;
int n, Q, x;
struct Edge { int v, nxt; } e[N * 2];
int G[N], edges = 0;
void adde(int u, int v) {
e[edges++] = (Edge){v, G[u]}; G[u] = edges - 1;
}
int inc(int a, int b) { return (a += b) >= P ? a - P : a; }
int qpow(int a, int b) {
int t = 1;
for (; b; b >>= 1, a = 1LL * a * a % P)
if (b & 1) t = 1LL * t * a % P;
return t;
}
int k[N], b[N];
void dfs(int u, int f, int S) {
if (S & (1 << u-1)) { k[u] = b[u] = 0; return; }
int deg = f ? 1 : 0, K = 0, B = 0;
for (int i = G[u], v; ~i; i = e[i].nxt)
if (v = e[i].v, v != f)
dfs(v, u, S), deg++, K = inc(K, k[v]), B = inc(B, b[v]);
k[u] = qpow(inc(deg, P - K), P - 2);
b[u] = 1LL * k[u] * (deg + B) % P;
}
int a[1 << N];
void fwt(int *a, int n, int op) {
for (int q = 1; q < n; q <<= 1)
for (int p = 0; p < n; p += q << 1)
for (int i = 0; i < q; i++)
a[p+q+i] = inc(a[p+q+i], a[p+i]);
}
int main() {
scanf("%d%d%d", &n, &Q, &x);
memset(G, -1, sizeof G);
for (int i = 1; i < n; i++) {
int u, v; scanf("%d%d", &u, &v);
adde(u, v), adde(v, u);
}
for (int S = 1; S < 1 << n; S++) {
int cnt = 0;
dfs(x, 0, S);
for (int i = 0; i < n; i++) cnt += (S >> i) & 1;
a[S] = cnt & 1 ? b[x] : P - b[x];
}
fwt(a, 1 << n, 1);
while (Q--) {
int k, S = 0; scanf("%d", &k);
for (int i = 1; i <= k; i++)
scanf("%d", &x), S |= 1 << x-1;
printf("%d
", a[S]);
}
return 0;
}