考虑计算一个重心被统计进答案的次数。
设当前枚举 (x) 作为切去某条边 ((u,v)) 后(其中 (dep_u<dep_v))树 (T_v)可能的重心,(g_x) 为 (x) 的最大子树的大小,(f_x) 为子树 (x) 的大小,如果点 (x) 不为根,就需要满足以下约束:
[egin{cases}
2g_xleq |T_v|
\
2(|T_v|-s_x)leq |T_v|
\
(u,v)
otin ext{Subtree}(x)
end{cases}
]
可以解得 (2g_xleq|T_v|leq2s_x),又有 (|T_v|=n-|T_u|),所以还可以得到 (n-2f_x leq|T_u|leq n-2g_x)。看上去似乎可以直接一个 ( ext{BIT}) 在树上维护?但是还有一个割边不在 (x) 的子树内的约束。那就考虑用总的答案减去割边在 (x) 的子树内的答案即可,这样就需要两个 ( ext{BIT})。
如果点 (x) 就是根呢?其实也是一样的,只不过现在的重心固定为根而已。可以在树上 ( ext{dfs}) 的时候顺便处理一下。
时间复杂度为 (O(n log n))。
#include<cstdio>
#include<assert.h>
typedef long long ll;
ll ans=0;
int n,cnt,rt,son1,son2;
int h[300005],to[600005],ver[600005];
int s1[300005],s2[300005],size[300005],max_size[300005],bel[300005];
inline int read() {
register int x=0,f=1;register char s=getchar();
while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9' ) {x=x*10+s-'0';s=getchar();}
return x*f;
}
inline int max(const int &x,const int &y) {return x>y? x:y;}
inline void add(int x,int y) {to[++cnt]=y; ver[cnt]=h[x]; h[x]=cnt;}
inline void add_val(int *s,int x,int val) {++x; for(;x<=n+1;x+=x&(-x)) s[x]+=val;}
inline int ask(int *s,int x) {int res=0; ++x; if(x<0) printf("%d
",x); for(;x;x-=x&(-x)) res+=s[x]; return res;}
inline void prework(int x,int fa) {
size[x]=1; max_size[x]=0;
for(register int i=h[x];i;i=ver[i]) {
int y=to[i]; if(y==fa) continue;
prework(y,x); size[x]+=size[y];
if(max_size[x]<size[y]) max_size[x]=size[y];
}
if(!rt&&max(n-size[x],max_size[x])<=n/2) rt=x;
}
inline void dfs(int x,int fa) {
if(x!=rt) {
add_val(s1,size[fa],-1);
add_val(s1,n-size[x],1);
ans+=x*(ll)(ask(s1,n-2*max_size[x])-ask(s1,n-2*size[x]-1));
ans+=x*(ll)(ask(s2,n-2*max_size[x])-ask(s2,n-2*size[x]-1));
if(!bel[x]&&bel[fa]) bel[x]=1;
ans+=rt*(size[x]<=n-2*size[bel[x]? son2:son1]);
}//S<=n-2*size[u]
for(register int i=h[x];i;i=ver[i]) {int y=to[i]; if(y==fa) continue; dfs(y,x);}
add_val(s2,size[x],1);
if(x!=rt) {
add_val(s1,size[fa],1);
add_val(s1,n-size[x],-1);
ans-=x*(ll)(ask(s2,n-2*max_size[x])-ask(s2,n-2*size[x]-1));
}
}
int main() {
int T=read();
while(T--) {
n=read(); cnt=0;
for(register int i=1;i<=n+1;++i) {h[i]=bel[i]=s1[i]=s2[i]=0;}
for(register int i=1;i<n;++i) {
int x=read(),y=read();
add(x,y); add(y,x);
}
ans=rt=son1=son2=0; prework(1,-1); prework(rt,-1);
for(register int i=h[rt];i;i=ver[i]) {
int y=to[i];
if(size[son1]<=size[y]) {son2=son1; son1=y;}
else {son2=(size[son2]<size[y]? y:son2);}
}//size[son1]>size[y]
for(register int i=1;i<=n;++i) add_val(s1,size[i],1);
bel[son1]=1; dfs(rt,-1);
printf("%lld
",ans);
}
return 0;
}