Description
给定一棵 (n) 个结点的树,你从点 (x) 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 (Q) 次询问,每次询问给定一个集合 (S),求如果从 (x) 出发一直随机游走,直到点集 (S) 中所有点都至少经过一次的话,期望游走几步。
特别地,点 (x)(即起点)视为一开始就被经过了一次。
答案对 $998244353 $ 取模。
Solution
考虑 min-max 容斥,问题变成求从 (x) 点出发第一次到集合 (S) 中的点的期望步数
枚举集合 (S),尝试树形DP求出
设 (f(i,S)) 表示从点 (i) 出发第一次到达集合 (S) 中的点的期望步数
则有转移 (f(i,S)=1+frac1{deg(i)}cdot f(fa[i],S)+frac1{deg(i)}sumlimits_{son}f(son,S))
根据我从没见过的树形DP常见套路,(f(i)) 一般都可以写成 (Acdot f(fa)+B) 的形式
然后推式子
[f(i,S)=1+frac1{deg(i)}cdot f(fa[i],S)+frac1{deg(i)}(sumlimits_{son}A(son)cdot f(x)+B(son))
]
[f(i,S)=frac1{deg(i)-sumlimits_{son}A(son)}cdot f(fa[i],S)+frac{deg(i)+sumlimits_{son}B(son)}{deg(x)-sumlimits_{son}A(son)}
]
得:(A(i)=frac1{deg(i)-sumlimits_{son}A(son)},B(i)=frac{deg(i)+sumlimits_{son}B(son)}{deg(x)-sumlimits_{son}A(son)})
当DP到一个 (S) 集合中的节点 (p) 时,令 (A(p)=B(p)=0),表示从这个点出发期望走 (0) 步就可以到达集合 (S) 中的点。最后的 (B(x)) 就是答案(因为 (f(fa[x])) 始终为 (0) )
Code
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
#define int long long
const int N=19;
const int mod=998244353;
int a[N],b[N],f[1<<N],cnts[1<<N];
int n,q,s,cnt,maxn,head[N],deg[N];
struct Edge{
int to,nxt;
}edge[N<<1];
void add(int x,int y){
edge[++cnt].to=y;
edge[cnt].nxt=head[x];
head[x]=cnt;
}
int ksm(int a,int b=mod-2,int ans=1){
while(b){
if(b&1) ans=ans*a%mod;
a=a*a%mod;b>>=1;
} return ans;
}
void dfs(int now,int fa,int S){
if(S>>now-1&1) return a[now]=b[now]=0,void();
for(int i=head[now];i;i=edge[i].nxt){
int to=edge[i].to;
if(to==fa) continue;
dfs(to,now,S);
(a[now]+=a[to])%=mod;(b[now]+=b[to])%=mod;
} a[now]=ksm((deg[now]-a[now]+mod)%mod);b[now]=(b[now]+deg[now])%mod*a[now]%mod;
}
int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
}
signed main(){
n=getint(),q=getint(),s=getint();maxn=1<<n;
for(int i=1;i<n;i++){
int x=getint(),y=getint();
add(x,y),add(y,x);deg[x]++;deg[y]++;
}
for(int i=1;i<maxn;i++){
cnts[i]=cnts[i>>1]+(i&1);
memset(a,0,sizeof a),memset(b,0,sizeof b);
dfs(s,0,i);
f[i]=(cnts[i]&1?b[s]:mod-b[s]);
}
for(int i=1;i<=n;i++)
for(int j=1;j<maxn;j++)
if(j>>i-1&1) (f[j]+=f[j^(1<<i-1)])%=mod;
while(q--){
int len=getint(),x=0;
while(len--) x|=1<<getint();
printf("%lld
",f[x>>1]);
} return 0;
}