Problem
题意:一棵 (n) 个结点的树,从点 (x) 出发,每次等概率随机选择一条与所在点相邻的边走过去,询问走完一个集合 (S)的期望时间,多组询问
(nleq 18,Qleq 5000)
Solution
首先来个(min-max)容斥
一下是看错题时想的
然后预处理从每个点开始的到达每个点的所有集合的期望,(O(n^22^n))卡常可过
- 若是这样,前20pts可以搞出来了:对于每次询问在线处理dp数组,利用最值容斥搞事情
- 30pts的部分是条链,可以对于每个部分做一次集合前缀后缀预处理
- 40pts的部分每次只询问一个点,明显可以预处理
- 事实上70pts的部分最简单,直接上容斥即可
接下来是100pts……
目前最暴力复杂度为(O(n^22^n+sum 2^{k_i})),若数据随机,期望状态下每次仅需计算(sum_{i=1}^nfrac {2^iinom ni}{2^n}approx 1478)次,即总共大约需要计算(18^2 imes 2^{18}+5000 imes 1478=92324656leq 10^8)次,理论可过,但出题人十有八九将其卡掉了(网上说没卡)
上头是看看错题的情况下想的,实际上所有询问中起点只可能有一个
重新理一遍思路:由于要(min-max)反演,即到达集合(S)中最后到达点的时间可以转化为(2^{|S|})个子集中最先到达点的时间进行容斥,再容斥即可
考虑如何求每个集合中最先到达点的期望时间,由期望的性质可得(设(f[x])表示从(x)点出发到达第一个集合点的时间期望,(v)表示与(x)相连的节点,(d_i)表示(i)点的度数):
由于需要对所有(2^n)个集合都做dp,酱紫求一次需要高斯消元(n^3),但总复杂度(O(n^32^n)approx 1.5 imes 10^9)一定接受不了
发现目前为止还没有利用最重要的一个性质:这是一棵树
要快速求解(f_x)数组,就需要一定的优化,由于这张图是一棵树,即每两点之间的路径是唯一的,可设(f_x=Af_{t}+B)(其中(t)为(x)的父亲)
利用上边的式子列出方程(设(c)为(x)的儿子):
由于儿子加父亲节点一共(d_x)个,即
我们要解出(A,B),所以需要将(f_x=Af_t+B)代入式子
解得:
再加上集合(S)内的点(x)满足(A_x=B_x=0)
然后对于每个集合 (S) 就可以 (O(n)) 地求出从任意点出发到达第一个集合内点的期望时间
这样询问就可以愉悦地 (mathrm{min-max}) 容斥了
(O(1)) 查询的话只需要高维前缀和一下即可,于是乎总时间复杂度为 (O(n2^n+Q))
upd:好像还要算上求逆元,复杂度 (O(n2^nlog p+Q)),算出来大概 (1.5e8),但由于枚举集合后一旦走到集合就递回,再算上快速幂的小常数,完全可过(复杂度跑不满,跑得最满的情况是菊花图,但即便是菊花图,常数最大也就 (frac 12))
Code
#include <cstdio>
#include <cctype>
inline void read(int&x){
char c11=getchar();x=0;while(!isdigit(c11))c11=getchar();
while(isdigit(c11))x=x*10+c11-'0',c11=getchar();
}
const int N=19,M=1<<18,p=998244353;
struct Edge{int v,nxt;}a[N*N];
int k[N],b[N],deg[N],head[N];
int f[M],bit[M],n,Q,st,_;
inline int qpow(int A,int B){
int res(1);
while(B){
if(B&1)res=1ll*res*A%p;
A=1ll*A*A%p,B>>=1;
}return res;
}
inline void pls(int&A,int B){A=A+B<p?A+B:A+B-p;}
inline void dec(int&A,int B){A=A-B<0?A-B+p:A-B;}
void dfs(int x,int las,int lim){
if(lim&(1<<x-1)){k[x]=b[x]=0;return ;}
k[x]=b[x]=deg[x];
for(int i=head[x];i;i=a[i].nxt)
if(a[i].v!=las){
dfs(a[i].v,x,lim);
dec(k[x],k[a[i].v]);
pls(b[x],b[a[i].v]);
}
k[x]=qpow(k[x],p-2);
b[x]=1ll*b[x]*k[x]%p;
}
int main(){
read(n),read(Q);read(st);int lim=1<<n;
for(int i=1,x,y;i<n;++i){
read(x),read(y),++deg[x],++deg[y];
a[++_].v=y,a[_].nxt=head[x],head[x]=_;
a[++_].v=x,a[_].nxt=head[y],head[y]=_;
}
for(int S=1;S<lim;++S){
bit[S]=bit[S>>1]+(S&1);
dfs(st,0,S);
f[S]=(bit[S]&1?b[st]:p-b[st]);
}
for(int i=1;i<lim;i<<=1)
for(int j=0;j<lim;++j)
if(i&j)pls(f[j],f[i^j]);
int t,x,s;
while(Q--){
read(t),s=0;
while(t--)read(x),s|=1<<x-1;
printf("%d
",f[s]);
}
return 0;
}