题目描述:
给定一棵 n 个点的树。
每次等概率选定一个联通块,将该联通内的所有点都捶一遍。再从选定的联通块中随机选取一个点,删掉该点及其连边。
反复操作,直至没有剩余点,求所有点被捶次数的期望 ×n! ,答案对 109+7 取模。
算法标签:点分治,fft
思路:
考虑一个点A会在另一个点B被选中时被锤到,当且仅当A和B路径上的点被锤到的顺序排列,B在最前,所以被锤到的概率是(B到A路径上的点数包括A,B)分之1。
用点分治和fft维护每一种路径长度路径数。
以下代码:
#include<bits/stdc++.h> #define il inline #define LL long long #define db double #define pi acos(-1) #define _(d) while(d(isdigit(ch=getchar()))) using namespace std; const int N=2e5+5,p=1e9+7; bool vis[N];int rt,d[N],ans,res[N]; int v[N],po[N],t,l,size,sz[N],mx[N],a[N]; int n,head[N],ne[N<<1],to[N<<1],maxd,cnt,jc[N],ny[N]; struct cp{ db r,i; cp(){};cp(db _r,db _i){r=_r;i=_i;} friend cp operator+(cp t1,cp t2){return cp(t1.r+t2.r,t1.i+t2.i);} friend cp operator-(cp t1,cp t2){return cp(t1.r-t2.r,t1.i-t2.i);} friend cp operator*(cp t1,cp t2){return cp(t1.r*t2.r-t1.i*t2.i,t1.i*t2.r+t2.i*t1.r);} il void clear(){r=i=0;} }b[N]; il int read(){ int x,f=1;char ch; _(!)ch=='-'?f=-1:f;x=ch^48; _()x=(x<<1)+(x<<3)+(ch^48); return f*x; } il void ins(int x,int y){ ne[++cnt]=head[x]; head[x]=cnt;to[cnt]=y; } il int mu(int x,int y){ if(x+y>=p)return x+y-p; return x+y; } il int ksm(LL a,int y){ LL b=1; while(y){ if(y&1)b=b*a%p; a=a*a%p;y>>=1; } return b; } il void dft(cp *x,int o){ for(int i=0;i<t;i++)if(i<v[i])swap(x[i],x[v[i]]); for(int i=1;i<t;i<<=1){ cp wn=cp(cos(pi/i),o*sin(pi/i)); for(int j=0;j<t;j+=i<<1){ cp w=cp(1,0); for(int k=0;k<i;k++,w=w*wn){ cp A=x[j+k],B=x[i+j+k]*w; x[j+k]=A+B;x[i+j+k]=A-B; } } } if(o<0)for(int i=0;i<t;i++)x[i].r/=t; } il void getrt(int x,int fa){ sz[x]=1;mx[x]=0; for(int i=head[x];i;i=ne[i]){ if(fa==to[i]||vis[to[i]])continue; getrt(to[i],x);sz[x]+=sz[to[i]]; if(sz[to[i]]>mx[x])mx[x]=sz[to[i]]; } if(size-sz[x]>mx[x])mx[x]=size-sz[x]; if(mx[rt]>mx[x])rt=x; } il void dfs(int x,int fa){ if(d[x]>maxd)maxd=d[x];a[d[x]]++; for(int i=head[x];i;i=ne[i]){ if(fa==to[i]||vis[to[i]])continue; d[to[i]]=d[x]+1;dfs(to[i],x); } } il void cal(int x,int vv){ t=1;l=0; while(t<=(maxd<<1))t<<=1,l++; for(int i=0;i<t;i++)v[i]=(v[i>>1]>>1)|((i&1)<<l-1); for(int i=0;i<=maxd;i++)b[i].r=a[i]; dft(b,1); for(int i=0;i<t;i++)b[i]=b[i]*b[i]; dft(b,-1); for(int i=0;i<=(maxd<<1);i++)res[i+1]=mu(res[i+1],((LL)(b[i].r+.5)*vv+p)%p); for(int i=0;i<=maxd;i++)a[i]=0;maxd=0; for(int i=0;i<t;i++)b[i].clear(); } il void solve(int x){ vis[x]=1;d[x]=0; dfs(x,0);cal(x,1); for(int i=head[x];i;i=ne[i]){ if(vis[to[i]])continue; d[to[i]]=1; dfs(to[i],0);cal(to[i],-1); } for(int i=head[x];i;i=ne[i]){ if(vis[to[i]])continue; size=sz[to[i]];rt=0; getrt(to[i],x);solve(rt); } } int main() { n=read();mx[0]=n; for(int i=1;i<n;i++){ int x=read(),y=read(); ins(x,y);ins(y,x); } jc[0]=1;for(int i=1;i<=n;i++)jc[i]=1ll*i*jc[i-1]%p; for(int i=1;i<=n;i++)ny[i]=ksm(i,p-2); t=1;l=0;while(t<=n)t<<=1,l++; size=n;getrt(1,0);solve(rt); for(int i=1;i<=n;i++)ans=mu(ans,1ll*res[i]*ny[i]%p); printf("%d ",1ll*ans*jc[n]%p); return 0; }