正题
题目链接:https://uoj.ac/problem/351
题目大意
给出\(n\)个点的一棵树,开始所有点都是白色,每次随机点黑一个叶子(可以重复点),求期望多少次能使得白色点构成的图直径发生变化。
答案对\(998244353\)取模
\(1\leq n\leq 5\times 10^5\)
解题思路
考虑什么时候会直径会产生变化。
假设直径的长度\(L\)为偶数,那么所有的直径都有一个共同的中心点,设为\(x\)。此时我们需要在\(x\)的两棵子树中各自找到两个深度为\(\frac L 2\)的叶子,那么就可以组成一条直径。
换句话说,把所有深度为\(\frac L 2\)叶子取出来,然后把它们按照在那个根的子树中分成若干个集合。然后当我们染色到只有一个集合没有全部染色的时候就结束了。
那么现在问题变成给出若干个集合和一些集合外的点,每次染一个点,求期望多少次能够染成只有一个集合没有全部染色。
考虑总共有\(n\)个点,有\(i\)个已经染色了,那么染色下任意一个的概率就是\(\frac{i}{n}\),期望就是\(\frac{n}{i}\)。
预处理\(f_i=\sum_{j=1}^i\frac{n}{j}\),然后我们可以考虑把集合中的点排列然后按顺序染,最后除上方案就好了。
假设所有集合中总共有\(m\)个点,目前枚举到的集合有\(k\)个点,然后染到这个集合剩下\(p\)个点的时候其他集合都染完了,那么期望就是
(中间的减法是为了保证最后剩下的\(p\)个点前面一定是一个不是枚举集合中的点,然后\(f_m-f_p\)是因为我们假设不在集合中的点已经染色了,那么剩下需要染\(m-p\)个)。
至于直径长度是奇数的情况,那么有两个中心点,也就是有一条中心边,分成两个集合按照上面的搞就好了。
时间复杂度:\(O(n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=5e5+10,P=998244353;
struct node{
ll to,next;
}a[N<<1];
ll n,m,cnt,mxdis,root,tot,ans;
ll ls[N],v[N],pre[N],fac[N],inv[N],f[N];
void addl(ll x,ll y){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;return;
}
void findL(ll x,ll fa,ll dis){
if(dis>mxdis)mxdis=dis,root=x;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa)continue;
pre[y]=x;findL(y,x,dis+1);
}
return;
}
void markL(ll x,ll fa,ll dis,ll k){
if(dis==mxdis/2&&(!a[ls[x]].next))v[k]++;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa)continue;
markL(y,x,dis+1,k);
}
return;
}
ll C(ll n,ll m)
{return fac[n]*inv[m]%P*inv[n-m]%P;}
signed main()
{
scanf("%lld",&n);
for(ll i=1,x,y;i<n;i++){
scanf("%lld%lld",&x,&y);
addl(x,y);addl(y,x);
}
ll k=0;
for(ll i=1;i<=n;i++)k+=!(a[ls[i]].next);
inv[0]=inv[1]=fac[0]=1;
for(ll i=2;i<N;i++)inv[i]=P-inv[P%i]*(P/i)%P;
for(ll i=1;i<N;i++)f[i]=(f[i-1]+k*inv[i]%P)%P;
for(ll i=1;i<N;i++)fac[i]=fac[i-1]*i%P,inv[i]=inv[i-1]*inv[i]%P;
findL(1,0,0);mxdis=0;
findL(root,0,0);
if(mxdis&1){
ll x=root;
for(ll i=1;i<=mxdis/2;i++)x=pre[x];
ll y=pre[x];markL(x,y,0,1);markL(y,x,0,2);
cnt=2;
}
else{
ll x=root;
for(ll i=1;i<=mxdis/2;i++)x=pre[x];
for(ll i=ls[x];i;i=a[i].next)
cnt++,markL(a[i].to,x,1,cnt);
}
for(ll i=1;i<=cnt;i++)m+=v[i];
for(ll i=1;i<=cnt;i++)
for(ll j=1;j<=v[i];j++){
ll w=(f[m]-f[j]+P)%P;
w=(fac[m-j]-(v[i]-j)*fac[m-j-1]%P+P)%P*fac[j]%P*w%P;
(ans+=w*inv[m]%P*C(v[i],j)%P)%=P;
}
printf("%lld\n",ans);
return 0;
}