@description@
给定一棵 n 个结点的树,你从点 x 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 Q 次询问,每次询问给定一个集合 S,求如果从 x 出发一直随机游走,直到点集 S 中所有点都至少经过一次的话,期望游走几步。
特别地,点 x(即起点)视为一开始就被经过了一次。
答案对 998244353 取模。
input
第一行三个正整数 n,Q,x。1≤n≤18,1≤Q≤5000。
接下来 n-1 行,每行两个正整数 (u,v) 描述一条树边。
接下来 Q 行,每行第一个数 k 表示集合大小,接下来 k 个互不相同的数表示集合 S。1≤k≤n。
output
输出 Q 行,每行一个非负整数表示答案。
sample input
3 5 1
1 2
2 3
1 1
1 3
2 2 3
3 1 2 3
2 1 2
sample output
0
4
4
4
1
sample explain
样例给的树是一条长度为 3 的链,且起点是链的一端,所以答案只跟最远的那个点有关。
当最远的点是 2 时,显然只需要 1 步就行了。
当最远的点是 3 时,通过计算可得期望 4 步到达。
@solution@
题目一整个暗示 min - max 容斥。
根据 min - max 容斥公式,我们有:
至于为什么,讨论每个值对答案的贡献即可。显然在期望意义下也成立(因为期望就是个线性组合)。
我们把 max 视作最后一个到达的集合内的点的步数,那么 min 就是最先到达集合内的点的步数。
怎么去求解 min(S) 呢?先将原树以 x 为根转换为有根树。
在给定 S 的前提下,我们不妨设 (f(i)) 表示从结点 i 出发到达集合 S 内的点的期望步数,进行 dp。
显然如果 (iin S),有 (f(i) = 0)。
否则,我们有 (f(i) = frac{sum_{cin child(i)}f(c)+f(fa(i))}{deg(i)} + 1)。
可以看到这是一个有后效性的 dp。但是如果直接高斯消元的话时间复杂度是会爆炸的。
我们可不可以利用一下树这个性质呢?
可以发现对于叶子来说,这个点的方程只含它和它父亲。因此,我们就可以用它父亲的 f 值表示这个点的 f 值。
然后呢?我们不妨再看,假如一个结点的儿子全是叶子结点,我们已经知道用这个点的 f 值可以表示它所有的儿子的 f 值。因此,这个结点的方程又只剩下它和它父亲了。
因此我们发现,结点 i 所对应的 f(i),总可以找到相应的 ai 与 bi,使得 (f(i)=a_i*f(fa(i))+b_i)。
当迭代到根的时候,根是没有父亲的,所以 (f(x) = b_x)。
然后再把这个值下传给儿子,就可以求出这个树所有点的 dp 值了。
那么,假若已经知道所有儿子的 a 和 b 的值,就可以代入状态转移式,通过解方程求出该点的 a 和 b 的值。具体求解过程留作习题。
解出来是:
然后就可以愉快地 dp 了。因为我们只需要从根出发的 f 值所以可以不用下传。
min - max 容斥的时候有的人是询问一次容斥一次 O(2^n*Q),有的人是预处理枚举子集 O(3^n*n) 然后 O(1) 回答询问(最关键的是竟然都能过???)。
但其实……完全用不着啊。我们完全可以 O(2^n*n) 预处理然后 O(1) 回答询问啊。
【好像存在不用 min-max 容斥的方法?不过看起来比较复杂……】
@accepted code@
#include<cstdio>
const int MAXN = 18;
const int MOD = 998244353;
int pow_mod(int b, int p) {
int ret = 1;
while( p ) {
if( p & 1 ) ret = 1LL*ret*b%MOD;
b = 1LL*b*b%MOD;
p >>= 1;
}
return ret;
}
struct edge{
int to; edge *nxt;
}edges[2*MAXN + 5], *adj[MAXN + 5], *ecnt=&edges[0];
void addedge(int u, int v) {
edge *p = (++ecnt);
p->to = v, p->nxt = adj[u], adj[u] = p;
p = (++ecnt);
p->to = u, p->nxt = adj[v], adj[v] = p;
}
int deg[MAXN + 5], a[MAXN + 5], b[MAXN + 5];
void dfs(int rt, int fa, int s) {
if( (1<<rt) & s ) {
a[rt] = b[rt] = 0;
return ;
}
a[rt] = b[rt] = deg[rt];
for(edge *p=adj[rt];p;p=p->nxt) {
if( p->to == fa ) continue;
dfs(p->to, rt, s);
a[rt] = (a[rt] + MOD - a[p->to])%MOD;
b[rt] = (b[rt] + b[p->to])%MOD;
}
a[rt] = pow_mod(a[rt], MOD-2);
b[rt] = 1LL*b[rt]*a[rt]%MOD;
}
int f[1<<MAXN], k[1<<MAXN];
int main() {
int n, Q, x; scanf("%d%d%d", &n, &Q, &x); x--;
for(int i=1;i<n;i++) {
int u, v; scanf("%d%d", &u, &v); u--, v--;
addedge(u, v); deg[u]++, deg[v]++;
}
int t = (1<<n);
for(int s=0;s<t;s++) {
dfs(x, -1, s), f[s] = b[x];
/*
printf("%d :
", s);
for(int i=0;i<n;i++)
printf("%d %d
", a[i], b[i]);
puts("");
*/
}
k[0] = -1;
for(int i=0;i<n;i++)
for(int s=0;s<t;s++)
if( (1<<i) & s ) f[s] = (f[s] + MOD - f[s^(1<<i)])%MOD;
for(int s=1;s<t;s++)
k[s] = k[s>>1]*((s&1) ? -1 : 1), f[s] = (MOD + k[s]*f[s])%MOD;
for(int i=1;i<=Q;i++) {
int k, x, s = 0; scanf("%d", &k);
for(int j=0;j<k;j++) {
scanf("%d", &x);
s = s | (1<<(x-1));
}
printf("%d
", f[s]);
}
}
@details@
一开始好像是因为减来减去一不小心减成负数了然后直接输出了负数……
我怎么这么傻啊……