枚举每条树边,将其断开,那么两侧肯定取带权重心最优。
考虑如何求出每个子树的重心,枚举其所有儿子,通过重量关系就可以判断出重心位于哪棵子树。
然后将那棵子树的重心暴力往上爬即可,因为每个点作为重心肯定是一段连续的链,所以复杂度为$O(n)$。
然后就是如何求出砍掉每棵子树之后剩下的部分的重心。
设当前点到根的路径为关键路径,那么可以通过二分求出重心在关键路径上哪个点的子树里。
对于那棵子树,重心要么是它本身,要么在它最重的子树里,要么在次重的子树里。
在线段树上按dfs序维护区间内子树重量的最大值,即可用线段树完成重心的查询,时间复杂度$O(log n)$。
总时间复杂度$O(nlog n)$。
#include<cstdio> #include<algorithm> using namespace std; typedef long long ll; const int N=500010,BUF=12000000; char Buf[BUF],*buf=Buf; int n,i,x,y,w[N],g[N],v[N<<1],nxt[N<<1],ed; int f[N],d[N],size[N],son[N],top[N],st[N],en[N],id[N],dfn,q[N],cq; int fir[N],sec[N],center[N],vip[N]; ll sum[N],sw[N],val[1050000],sd[N],su[N],ans=1LL<<60; inline void read(int&a){for(a=0;*buf<48;buf++);while(*buf>47)a=a*10+*buf++-48;} inline void add(int x,int y){v[++ed]=y;nxt[ed]=g[x];g[x]=ed;} void dfs(int x){ size[x]=1; for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x]){ f[v[i]]=x,d[v[i]]=d[x]+1; dfs(v[i]),size[x]+=size[v[i]]; if(size[v[i]]>size[son[x]])son[x]=v[i]; } } void dfs2(int x,int y){ id[st[x]=++dfn]=x;top[x]=y; if(son[x])dfs2(son[x],y); for(int i=g[x];i;i=nxt[i])if(v[i]!=son[x]&&v[i]!=f[x])dfs2(v[i],v[i]); en[x]=dfn; } inline int dis(int x,int y){ int t=d[x]+d[y]; for(;top[x]!=top[y];x=f[top[x]])if(d[top[x]]<d[top[y]])swap(x,y); if(d[x]>d[y])swap(x,y); return t-2*d[x]; } inline void cal(int x){ int i=center[fir[x]]; ll t=sum[fir[x]]+sd[x]-sd[fir[x]]-sw[fir[x]]+(sw[x]-sw[fir[x]])*(d[i]-d[x]); while(2*sw[i]<sw[x])t+=2*sw[i]-sw[x],i=f[i]; center[x]=i; sum[x]=t; } void dfs3(int x){ sw[x]=w[x]; for(int i=g[x];i;i=nxt[i]){ int y=v[i]; if(y==f[x])continue; dfs3(y); sw[x]+=sw[y]; if(sw[y]>sw[fir[x]])sec[x]=fir[x],fir[x]=y; else if(sw[y]>sw[sec[x]])sec[x]=y; sd[x]+=sd[y]+sw[y]; } if(2*sw[fir[x]]<=sw[x]){ center[x]=x; sum[x]=sd[x]; return; } cal(x); } void dfs4(int x){ if(f[x]){ int y=f[x]; su[x]=su[y]+sd[y]-sd[x]-2*sw[x]+sw[1]; } for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x])dfs4(v[i]); } inline int lower(ll x){ int l=2,r=cq,mid,t=1; while(l<=r)if(2*(sw[q[mid=(l+r)>>1]]-x)>=sw[1]-x)l=(t=mid)+1;else r=mid-1; return q[t]; } void build(int x,int a,int b){ if(a==b){val[x]=sw[id[a]]*2;return;} int mid=(a+b)>>1; build(x<<1,a,mid),build(x<<1|1,mid+1,b); val[x]=max(val[x<<1],val[x<<1|1]); } int ask(int x,int a,int b,int c,int d,ll p){ if(val[x]<p)return 0; if(a==b)return a; int mid=(a+b)>>1,t=0; if(d>mid)t=ask(x<<1|1,mid+1,b,c,d,p); if(t)return t; if(c<=mid)t=ask(x<<1,a,mid,c,d,p); return t; } void dfs5(int x){ if(f[x])vip[x]=lower(sw[x]); q[++cq]=x; for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x])dfs5(v[i]); cq--; } inline void solve(int x){ int t=vip[x],y,z=0; if(st[fir[t]]<=st[x]&&en[x]<=en[fir[t]])y=sec[t];else y=fir[t]; if(y)z=ask(1,1,n,st[y],en[y],sw[1]-sw[x]); if(!z)z=t;else z=id[z]; ans=min(ans,sum[x]+sd[z]+su[z]-sd[x]-sw[x]*dis(x,z)); } int main(){ fread(Buf,1,BUF,stdin);read(n); for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x); for(i=1;i<=n;i++)read(w[i]); dfs(1);dfs2(1,1); dfs3(1);dfs4(1); build(1,1,n);dfs5(1); for(i=2;i<=n;i++)solve(i); return printf("%lld",ans),0; }