「UOJ351」新年的叶子
题目描述
有一棵大小为 (n) 的树,每次随机将一个叶子染黑,可以重复染,问期望染多少次后树的直径会缩小。(1 leq n leq 5 imes 10^5)
解题思路 :
首先要利用一个经典的结论,树的所有直径的中心为同一个点/边。不妨给每条边加一个虚拟点,这样整颗树的直径就只会交于同一个点了。
接下来考虑树的直径是由中心的两个儿子的两个深度为 (maxdep) 的叶子构成的,所以问题等价于将叶子根据中心的儿子分成若干个集合,对于所有染色方案求染到只剩一个集合没有被完全染黑的期望步数之和,这个东西再除以一个方案数就是答案。
这个东西好难求啊,推了半天式子还是不太会,最后只能看题解辅助推导 ( ext{qwq}) 。首先把随机染黑一个叶子转化为随机染黑一个没有染黑的叶子,那么此时的期望步数就是总叶子数除以没有被染黑的叶子数。然后枚举一下最后一个被染黑的集合是哪个集合,在第一次到达只有它没被完全染黑这个状态之前其被染黑的叶子总数是多少,不妨分别设他们为 (i, k) 。那么首先先选出 (k) 个叶子染黑,其在操作序列中有 (n - sz_i+k-1) 个位置可以放,(-1) 是因为不能放最后一个位置,如果放了最后一个位置其就不是第一次到达的状态就算重了,然后把其它的操作的排列数乘上就是方案数。而步数就是 (sum_{i=sz_i-k+1}^{n} frac{m}{i}) ,其中 (n) 是集合总大小,(m) 是总的叶子数量。 由于我的写法常数太大,在 UOJ 上的统计榜成功吃鸡。
/*program by mangoyang*/
#pragma GCC optimize("Ofast","inline","-ffast-math")
#pragma GCC target("avx,sse2,sse3,sse4,mmx")
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
#define pii pair<int, int>
#define fi first
#define se second
const int N = 1000005, mod = 998244353;
vector<int> g[N];
int dep[N], f[N], js[N], inv[N], buf[N], n, m, all, mxdep, rt;
ll h[N], ans;
namespace prework{
queue<pii> q; int vis[N], pre[N];
inline pii bfs(int s){
memset(vis, 0, sizeof(vis));
q.push(make_pair(s, 0)), vis[s] = 1; pii res;
while(!q.empty()){
pii now = q.front(); q.pop();
int x = now.fi, dis = now.se; res = now;
for(int i = 0; i < g[x].size(); i++){
int v = g[x][i];
if(vis[v]) continue;
vis[v] = 1, pre[v] = x;
q.push(make_pair(v, dis + 1));
}
}
return res;
}
inline void dfs(int u, int fa){
dep[u] = dep[fa] + 1, f[u] = dep[u], buf[u] = 1;
for(int i = 0; i < g[u].size(); i++){
int v = g[u][i];
if(v == fa) continue;
dfs(v, u);
if(f[v] > f[u]) buf[u] = buf[v], f[u] = f[v];
else if(f[v] == f[u]) buf[u] += buf[v];
}
if(g[u].size() == 1 && fa) m++;
}
inline void realmain(){
pii s1 = bfs(1), s2 = bfs(s1.fi);
int dis = s2.se / 2; rt = s2.fi;
for(int i = 1; i <= dis; i++) rt = pre[rt];
dfs(rt, 0);
}
}
inline void up(ll &x, int y){ (x += y) %= mod; }
inline int Pow(int a, int b){
int ans = 1;
for(; b; b >>= 1, a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline int C(int x, int y){
return 1ll * js[x] * inv[y] % mod * inv[x-y] % mod;
}
inline int calc(int x, int y){
ll res = 1ll * C(y, x) * C(all-(y-x)-1, x) % mod;
(res *= (1ll * js[y-x] * js[all-y] % mod)) %= mod;
(res *= (1ll * js[x] * (h[all] - h[y-x]) % mod)) %= mod;
return res;
}
int main(){
js[0] = inv[0] = 1;
for(int i = 1; i < N; i++)
js[i] = 1ll * js[i-1] * i % mod, inv[i] = Pow(js[i], mod - 2);
read(n); int size = n;
for(int i = 1, x, y; i < n; i++){
read(x), read(y), ++size;
g[x].push_back(size), g[size].push_back(x);
g[y].push_back(size), g[size].push_back(y);
}
prework::realmain();
for(int i = 1; i <= m; i++)
up(h[i], 1ll * (h[i-1] + 1ll * m * Pow(i, mod - 2)) % mod);
for(int i = 0; i < g[rt].size(); i++) mxdep = max(mxdep, f[g[rt][i]]);
for(int i = 0; i < g[rt].size(); i++){
if(f[g[rt][i]] == mxdep) all += buf[g[rt][i]];
else buf[g[rt][i]] = 0;
}
for(int i = 0; i < g[rt].size(); i++)
for(int j = 0; j < buf[g[rt][i]]; j++) up(ans, calc(j, buf[g[rt][i]]));
cout << (1ll * ans * inv[all] % mod + mod) % mod;
return 0;
}