正题
题目链接:https://www.luogu.com.cn/problem/P5666
题目大意
给出\(n\)个点的一棵树,对于每条边割掉后两棵树重心编号和。
\(1\leq T\leq 5,1\leq n\leq 299995\)
解题思路
编号和,所以应该是要我们枚举点然后求有多少条边割掉后它能当重心。
随便以一个点为根的话,对于一个点,割掉它子树外面的一条边,设割去的连通块大小为\(k\)那么需要满足,以及点\(x\)的最大子节点子树为\(f_x\)
\[\left\{\begin{matrix}2(n-k-s_x)\leq n-S\\2f_x\leq n-k\end{matrix}\right.
\]
移一下项就有
\[n-2s_x\leq k\leq n-2f_x
\]
但是子树里面就很难搞了,因为我们需要维护子树里所有子树的大小,其中一种方法是用线段树合并或者主席树像YbtOJ#662-交通运输这题一样搞。
很麻烦对啊吧,转换一下思路。其实有一个性质就是如果我们选择了原树的重心作为根节点,那么子节点无论如何割掉子树中的边也不会是重心。
所以这样就可以去掉这种麻烦的情况了。
只考虑前面那种,我们需要找到分割大小在\([n-2s_x,n-2f_x]\)这个区间的边,并且还要不在子树内。
如果不考虑不在子树内的话挺好搞,对于根节点到该节点的路径都是\(n-siz_x\),否则是\(siz_x\)丢进树状数组里查询就好了,边往下做边改树状数组就好了。
还要减去子树内的,好像还是要和上面一样用线段树合并?
我们可以用进入子树后的总共答案减去进入子树前的总共答案就是子树里面的答案了
这样好写很多,时间复杂度\(O(n\log n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define lowbit(x) (x&-x)
using namespace std;
const ll N=3e5+10;
struct node{
ll to,next;
}a[N<<1];
ll T,n,ans,tot,rt,u,v;
ll siz[N],f[N],ls[N],t1[N],t2[N];
void addl(ll x,ll y){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;return;
}
void Change(ll *t,ll x,ll val){
x++;
while(x<=n+1){
t[x]+=val;
x+=lowbit(x);
}
return;
}
ll Ask(ll *t,ll x){
ll ans=0;x++;
if(x>=n+1)x=n+1;else if(x<0)x=0;
while(x){
ans+=t[x];
x-=lowbit(x);
}
return ans;
}
void dfs(ll x,ll fa){
siz[x]=1;f[x]=0;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa)continue;
dfs(y,x);siz[x]+=siz[y];
f[x]=max(f[x],siz[y]);
}
if(max(f[x],n-siz[x])<=n/2)rt=x;
return;
}
void calc(ll x,ll fa,bool flag){
Change(t1,siz[fa],-1);
Change(t1,n-siz[x],1);
ll tmp=Ask(t1,n-2*f[x])-Ask(t1,n-2*siz[x]-1);
tmp+=Ask(t2,n-2*f[x])-Ask(t2,n-2*siz[x]-1);
ans+=tmp*x;
ans+=rt*(siz[x]<=n-2*siz[flag?v:u]);
Change(t2,siz[x],1);
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa)continue;
calc(y,x,flag);
}
Change(t1,siz[fa],1);
Change(t1,n-siz[x],-1);
ans-=(Ask(t2,n-2*f[x])-Ask(t2,n-2*siz[x]-1))*x;
}
signed main()
{
scanf("%lld",&T);
while(T--){
memset(ls,0,sizeof(ls));
tot=rt=ans=u=v=0;
scanf("%lld",&n);
for(ll i=1;i<n;i++){
ll x,y;
scanf("%lld%lld",&x,&y);
addl(x,y);addl(y,x);
}
dfs(1,0);dfs(rt,0);
for(ll i=ls[rt];i;i=a[i].next){
ll y=a[i].to;
if(siz[y]>siz[u])v=u,u=y;
else if(siz[y]>siz[v])v=y;
}
memset(t1,0,sizeof(t1));
memset(t2,0,sizeof(t2));
for(ll i=1;i<=n;i++)Change(t1,siz[i],1);
for(ll i=ls[rt];i;i=a[i].next)
calc(a[i].to,rt,(a[i].to==u));
printf("%lld\n",ans);
}
return 0;
}