题目
在一棵树里随机点分治,求每次选点的子树大小期望。
(nle 10^5)。
解法
其实很重要的一点就是意识到这个选点方式就是随机点分治诶。
很自然地,因为没法计算子树大小,我们将贡献转移到单点上,也即每个点的贡献就是在点分树上的深度(它的祖先比它先移动且和此点在一个联通快)。
问题转化为每个点深度期望和,而每个点深度期望就是其他点是它祖先的概率和。
考虑对于点对 ((x,y)),我们使 (x) 为 (y) 祖先的概率。
而使 (x) 为 (y) 祖先的条件就是:(x) 是原树上 (x,y) 路径上的点在点分树上最小深度的点。容易发现如果还有深度更小的点(更早移动)就有两种情况
- 深度最小为 (y)。(y) 为 (x) 祖先。
- 深度最小为其他点。显然 (x,y) 会被划分到不同子树。
令 ( ext{dis}(x,y)) 为原树上 (x,y) 之间(包括本身)的点数。那么 (x) 满足条件的概率就是 (frac{1}{ ext{dis}(x,y)})。
为啥捏?你会发现选点总方案是一个全排列,而且每一位上填某个数的概率是固定的。所以实际上我们只用考虑路径内部先后关系即可。
答案就是 (sumsum frac{1}{ ext{dis}(x,y)})。
还有一个自然的转化:枚举路径长度然后算方案数。这个就是比较经典的点分治了,但是不是固定求长度为 (k),直接 (mathtt{DP}) 显然超时,可以用 (mathtt{FFT}) 来优化。
代码
#include <cstdio>
#define rep(i,_l,_r) for(register signed i=(_l),_end=(_r);i<=_end;++i)
#define fep(i,_l,_r) for(register signed i=(_l),_end=(_r);i>=_end;--i)
#define erep(i,u) for(signed i=head[u],v=to[i];i;i=nxt[i],v=to[i])
#define efep(i,u) for(signed i=Head[u],v=to[i];i;i=nxt[i],v=to[i])
#define print(x,y) write(x),putchar(y)
template <class T> inline T read(const T sample) {
T x=0; int f=1; char s;
while((s=getchar())>'9'||s<'0') if(s=='-') f=-1;
while(s>='0'&&s<='9') x=(x<<1)+(x<<3)+(s^48),s=getchar();
return x*f;
}
template <class T> inline void write(const T x) {
if(x<0) return (void) (putchar('-'),write(-x));
if(x>9) write(x/10);
putchar(x%10^48);
}
template <class T> inline T Max(const T x,const T y) {if(x>y) return x; return y;}
template <class T> inline T Min(const T x,const T y) {if(x<y) return x; return y;}
template <class T> inline T fab(const T x) {return x>0?x:-x;}
template <class T> inline T gcd(const T x,const T y) {return y?gcd(y,x%y):x;}
template <class T> inline T lcm(const T x,const T y) {return x/gcd(x,y)*y;}
template <class T> inline T Swap(T &x,T &y) {x^=y^=x^=y;}
#include <cmath>
#include <vector>
#include <cstring>
#include <iostream>
using namespace std;
#define int long long
const int maxn=1e5+5,mod=1e9+7;
const double Pi=acos(-1.0);
int n,dfn[maxn],idx,maxSon[maxn],siz[maxn],dep,g[maxn],f[maxn],c[maxn];
bool vis[maxn];
vector <int> e[maxn];
struct cp {
double x,y;
cp operator + (const cp t) const {
return (cp){x+t.x,y+t.y};
}
cp operator - (const cp t) const {
return (cp){x-t.x,y-t.y};
}
cp operator * (const cp t) const {
return (cp){x*t.x-y*t.y,y*t.x+x*t.y};
}
} wn,w,tmp,s[maxn<<2];
void FFT(cp *t,int lim,int op=1) {
static int rev[maxn<<2];
rep(i,0,lim-1) {
rev[i]=(rev[i>>1]>>1)|((i&1)*(lim>>1));
if(i<rev[i]) swap(t[i],t[rev[i]]);
}
for(int mid=1;mid<lim;mid<<=1) {
wn=(cp){cos(Pi/mid),sin(Pi/mid)*op};
for(int i=0,p=(mid<<1);i<lim;i=i+p) {
w=(cp){1,0};
for(int j=0;j<mid;++j,w=w*wn) {
tmp=w*t[i+j+mid];
t[i+j+mid]=t[i+j]-tmp,t[i+j]=t[i+j]+tmp;
}
}
}
}
void dfs(int u,int fa) {
dfn[++idx]=u; siz[u]=1,maxSon[u]=0;
for(int i=0;i<e[u].size();++i) {
int v=e[u][i];
if(v==fa || vis[v]) continue;
dfs(v,u);
siz[u]+=siz[v];
maxSon[u]=Max(maxSon[u],siz[v]);
}
}
int getRoot(int u) {
idx=0,dfs(u,0);
int rt=0;
rep(i,1,idx) {
maxSon[dfn[i]]=Max(maxSon[dfn[i]],idx-siz[dfn[i]]);
if(maxSon[rt]>maxSon[dfn[i]]) rt=dfn[i];
}
return rt;
}
void Dfs(int u,int fa,int len) {
dep=Max(dep,len),++g[len];
for(int i=0;i<e[u].size();++i) {
int v=e[u][i];
if(vis[v] || v==fa) continue;
Dfs(v,u,len+1);
}
}
void polySqr(int *t,int dep) {
// 一个优化 FFT 次数的小 trick:当做平方卷积时可以将实部与虚部都赋值为 val,这样算出来的点值显然实部与虚部相等(设其为 Value)。但 [0,lim-1] 循环由于是虚数相乘,会得到 (0,2*Value^2),所以需要将 s[i].y 除以 2
rep(i,0,dep) s[i].x=s[i].y=t[i];
int lim=1;
while(lim<=(dep*2)) lim<<=1;
FFT(s,lim);
rep(i,0,lim-1) s[i]=s[i]*s[i];
FFT(s,lim,-1);
rep(i,0,dep+dep) t[i]=(int)(s[i].y/lim/2+0.49);
memset(s,0,sizeof(cp)*(lim+5));
}
void calc(int u) {
// 直接合并会导致点分治复杂度不对,所以要用容斥:计算总卷积,显然多的部分就是子树内部的卷积
int len=0; ++f[0];
for(int i=0;i<e[u].size();++i) {
int v=e[u][i];
if(vis[v]) continue;
dep=0;
Dfs(v,u,1);
len=Max(len,dep);
rep(j,0,dep) f[j]=f[j]+g[j];
polySqr(g,dep);
rep(j,0,dep<<1) c[j]=c[j]-g[j];
memset(g,0,sizeof(int)*((dep<<1)+5));
}
polySqr(f,len);
rep(i,0,len<<1) c[i]=c[i]+f[i];
memset(f,0,sizeof(int)*((len<<1)+5));
}
void work(int u) {
calc(u); vis[u]=1;
for(int i=0;i<e[u].size();++i)
if(!vis[e[u][i]]) work(getRoot(e[u][i]));
}
int qkpow(int x,int y) {
int r=1;
while(y) {
if(y&1) r=1ll*r*x%mod;
x=1ll*x*x%mod; y>>=1;
}
return r;
}
signed main() {
n=read(9);
int x,y;
rep(i,1,n-1) x=read(9),y=read(9),e[x].push_back(y),e[y].push_back(x);
maxSon[0]=n;
work(getRoot(1));
int ans=0;
rep(i,0,n-1) ans=(ans+1ll*c[i]%mod*qkpow(i+1,mod-2)%mod)%mod;
rep(i,1,n) ans=1ll*ans*i%mod;
print(ans,'
');
return 0;
}