题目大意:
给定一棵 n 个结点的树,你从点 x 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 Q 次询问,每次询问给定一个集合 S,求如果从 x 出发一直随机游走,直到点集 S 中所有点都至少经过一次的话,期望游走几步。
特别地,点 x(即起点)视为一开始就被经过了一次。
答案对 998244353 取模。
思路:
看到所有点都经过一次直接上min-max反演。
然后我们要求每个集合的min,即标记关键点之后求期望意义下第一次到关键点要走的步数,上dp的话发现转移式子有环,树上高斯消元即可。
为了方便我们可以把所有答案都处理出来,但是这样复杂度为(3^n),常数大可能通过不了,发现我们需要求的就是每个状态的子集和,上FWT和FMT都可以将时间复杂度优化到(2^n imes n)。
/*=======================================
* Author : ylsoi
* Time : 2019.2.10
* Problem : loj2542
* E-mail : ylsoi@foxmail.com
* ====================================*/
#include<bits/stdc++.h>
#define REP(i,a,b) for(register int i=a,i##_end_=b;i<=i##_end_;++i)
#define DREP(i,a,b) for(register int i=a,i##_end_=b;i>=i##_end_;--i)
#define debug(x) cout<<#x<<"="<<x<<" "
#define fi first
#define se second
#define mk make_pair
#define pb push_back
typedef long long ll;
using namespace std;
void File(){
freopen("loj2542.in","r",stdin);
freopen("loj2542.out","w",stdout);
}
template<typename T>void read(T &_){
_=0; T f=1; char c=getchar();
for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
for(;isdigit(c);c=getchar())_=(_<<1)+(_<<3)+(c^'0');
_*=f;
}
const int maxn=18+5;
const int maxw=(1<<19)+10;
const int mod=998244353;
int n,q,rt,all,cnt[maxw];
vector<int>G[maxn];
ll d[maxn],a[maxn],b[maxn],mx[maxw];
bool c[maxn];
inline ll qpow(register ll x,register ll y){
ll ret=1; x%=mod;
while(y){
if(y&1)ret=ret*x%mod;
x=x*x%mod;
y>>=1;
}
return ret;
}
inline ll inv(register ll x){return qpow(x,mod-2);}
inline void dfs(register int u,register int fh){
ll sa=0,sb=0;
REP(i,0,G[u].size()-1){
int v=G[u][i];
if(v==fh)continue;
dfs(v,u);
sa=(sa+a[v])%mod;
sb=(sb+b[v])%mod;
}
if(c[u])a[u]=b[u]=0;
else{
a[u]=inv(d[u]-sa);
b[u]=(sb+d[u])*a[u]%mod;
}
}
int main(){
File();
read(n),read(q),read(rt);
all=(1<<n)-1;
int u,v;
REP(i,1,n-1){
read(u),read(v);
G[u].pb(v),++d[u];
G[v].pb(u),++d[v];
}
c[2]=1;
dfs(rt,0);
REP(S,1,all)cnt[S]=__builtin_popcount(S)%2 ? 1 : -1;
REP(S,1,all){
REP(i,1,n)if((1<<(i-1))&S)c[i]=1;
else c[i]=0;
dfs(rt,0);
mx[S]=b[rt]*cnt[S];
}
for(int len=1;len<=all;len<<=1)
for(int L=0;L<=all;L+=len<<1)
REP(i,L,L+len-1)
mx[i+len]=(mx[i+len]+mx[i])%mod;
int S=0,sz,x;
REP(i,1,q){
S=0;
read(sz);
REP(j,1,sz)read(x),S^=1<<(x-1);
printf("%lld
",(mx[S]+mod)%mod);
}
return 0;
}