题目:https://loj.ac/problem/2542
可以最值反演。注意 min 不是独立地算从根走到每个点的最小值,在点集里取 min ,而是整体来看,“从根开始走到点集中的任意一个点就停下”的期望步数。
设 f[ i ] 表示从根走到 i ,再走期望几步就能走到点集中的某个点。有 ( f[i]=frac{1}{d[i]}sumlimits_{j}(f[j]+1) ) ( j 是和 i 有边的点)
于是要“树上高斯消元”。其实就是尝试写成 ( f[i]=a[i]*f[st]+b[i] ) (st 是根)之类的形式,从而让系数的转移有方向,然后根据 ( a[st] ) 和 ( b[st] ) 算 ( f[st] ) 。
为了有方向,这里设 ( f[i]=a[i]*f[st]+b[i]*f[fa]+c[i] ) (有 ( a[i]*f[st] ) 是为了算 ( f[st] ) )
( f[i]=frac{1}{d[i]}f[fa]+frac{1}{d[i]}+frac{1}{d[i]}sumlimits_{j in child}f[j]+frac{d[i]-1}{d[i]} )
( d[i]*f[i]=f[fa]+d[i]+sumlimits_{j in child}a[j]f[st]+sumlimits_{j in child}b[j]f[i]+sumlimits_{j in child}c[j] )
( (d[i]-sumlimits_{j in child}b[j])f[i]=sumlimits_{j in child}a[j]f[st]+f[fa]+d[i]+sumlimits_{j in child}c[j] )
然后对于每个点集可以树形DP地算。如果走到了点集中的点,那么 a[cr] 、b[cr] 、c[cr] 都是 0 ,并且直接 return 即可。
最值反演的时候求子集的和可以用 fmt 算。那个 -1 的系数只要在初值的时候体现一下就行了。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=20,M=(1<<18)+5,mod=998244353; int upt(int x){if(x<0)x+=mod;if(x>=mod)x-=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} int n,st,hd[N],xnt,to[N<<1],nxt[N<<1],dg[N],a[N],b[N],c[N]; int bin[N],f[M],ct[M]; void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;dg[x]++;} void dfs(int cr,int fa,int s) { a[cr]=b[cr]=c[cr]=0;int tp=0; if(s&bin[cr-1])return;//////return! for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dfs(v,cr,s); a[cr]=upt(a[cr]+a[v]); c[cr]=upt(c[cr]+c[v]); tp=upt(tp+b[v]); } tp=pw(upt(dg[cr]-tp),mod-2); a[cr]=(ll)a[cr]*tp%mod; b[cr]=tp; c[cr]=(ll)(c[cr]+dg[cr])*tp%mod; } void fmt() { for(int i=1;i<bin[n];i<<=1) for(int s=0;s<bin[n];s++) if(s&i)f[s]=upt(f[s]+f[s^i]); } int main() { n=rdn();int Q=rdn();st=rdn(); for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v),add(v,u); bin[0]=1;for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1; for(int s=1;s<bin[n];s++)ct[s]=ct[s-(s&-s)]+1; for(int s=1;s<bin[n];s++) { dfs(st,0,s);f[s]=(ll)c[st]*pw(upt(1-a[st]),mod-2)%mod; if((ct[s]&1)==0)f[s]=upt(-f[s]); } fmt(); while(Q--) { n=rdn();int s=0; for(int i=1;i<=n;i++)s|=bin[rdn()-1]; printf("%d ",f[s]); } return 0; }