个人感觉这题称不上毒瘤。
首先看到选一条路径之类的字眼可以轻松想到点分治,也就是我们每次取原树的重心 (r) 并将路径分为经过重心和不经过重心两类路径。对于第二类路径我们肯定可以在对 (r) 所有子树进一步分治的过程中将其变成第一类路径,因此我们只用考虑怎样计算第一类路径的贡献即可。
显然对于一条第一类路径 (x o y),我们要将其拆成 (x o r) 和 (r o y) 两部分分别处理,考虑怎样合并这两部分的贡献,我们记 (dep_x) 为 (x o r) 路径上点的个数,(sum_x) 表示 (x o r) 上所有点的权值之和,那么对于 (x o r) 上某个点 (z) 而言,它对和的贡献即为 ((dep_x-dep_z+1) imes a_z),同理对于 (r o y) 上某个点 (z),它对和的贡献即为 ((dep_x-1+dep_z) imes a_z),那么我们考虑再设一个 (sum1_x) 表示 (x o r) 路径上所有点的 ((dep_x-dep_z+1) imes a_z),(sum2_x) 表示 (x o r) 路径上所有点的 (dep_z imes a_z) 之和。对于 (x) 的某个儿子 (y),显然有 (sum1_y=sum1_x+sum_y,sum2_y=sum2_x+dep_ya_y)。
因此上面 (x o y) 路径的权值就可以写作 (sum1_x+(dep_x-1)sum_y+sum2_y),但注意到 (r) 既在 (x o r) 的路径上也在 (r o y) 的路径上,它的贡献被计算了两次,因此需减去 (dep_xa_r),得到 (sum1_x+(dep_x-1)sum_y+sum2_y-dep_xa_r),考虑怎样快速维护这个东西,我们显然要枚举 (x,y) 中的一个并快速维护另一个变量的决策,那么我们枚举什么好呢?注意到如果我们枚举 (y),那么我们相当于求若干条形如 (f_x(t)=(dep_x-1)t+sum1_x-dep_xa_r) 的直线在 (t=sum_y) 处的最大值,而 (sum_y) 的值域过大不好直接李超线段树,要写个动态凸壳,过于毒瘤,因此考虑枚举 (x),这样相当于我们要求若干条形如 (f_y(t)=sum_y·t+sum2_y) 的直线在 (t=dep_x-1) 处的最大值,这个值域是在 ([0,n-1]) 范围内的,就可以李超线段树维护了。
算下复杂度,点分 1log,李超线段树全局插入是 1log 的,因此总复杂度 2log,可以通过此题。
最后说几个注意点:
- 在点分治过程中注意要正反各跑一遍,因为假设按顺序枚举的子树依次是 (t_1,t_2,cdots,t_m),那么对于某个 (i<j),有可能是 (xin t_i,yin t_j),也有可能 (xin t_j,yin t_i)。
- 每次点分完一个点后要清空李超树,但重新建树复杂度过高,因此可以考虑将修改的点记录在一个
std::vector
中每次将vector
中的点的最大优势线段赋为空即可。
const int MAXN=1.5e5;
const ll INF=0x3f3f3f3f3f3f3f3fll;
int n,a[MAXN+5],hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
struct line{
ll k,b;
line(ll _k=0,ll _b=-INF):k(_k),b(_b){}
ll get(int x){return 1ll*k*x+b;}
};
struct node{int l,r;line v;} s[MAXN*4+5];
void build(int k,int l,int r){
s[k].l=l;s[k].r=r;if(l==r) return;
int mid=l+r>>1;build(k<<1,l,mid);build(k<<1|1,mid+1,r);
}
vector<int> opts;
void modify(int k,line v){
int mid=s[k].l+s[k].r>>1;
ll l1=s[k].v.get(s[k].l),r1=s[k].v.get(s[k].r),m1=s[k].v.get(mid);
ll l2=v.get(s[k].l),r2=v.get(s[k].r),m2=v.get(mid);
if(l1>=l2&&r1>=r2) return;
if(l2>=l1&&r2>=r1) return s[k].v=v,opts.pb(k),void();
if(m2>=m1){
if(l2<l1) return modify(k<<1,s[k].v),s[k].v=v,void();
else return modify(k<<1|1,s[k].v),s[k].v=v,void();
} else {
if(l2<l1) return modify(k<<1|1,v),void();
else return modify(k<<1,v),void();
}
}
ll query(int k,int x){
if(s[k].l==s[k].r) return s[k].v.get(x);int mid=s[k].l+s[k].r>>1;
return max(s[k].v.get(x),(x<=mid)?query(k<<1,x):query(k<<1|1,x));
}
bool vis[MAXN+5];int siz[MAXN+5],mx[MAXN+5],cent=0,subsiz[MAXN+5];
ll ans=0,sum[MAXN+5],_sum[MAXN+5],__sum[MAXN+5];int dep[MAXN+5];
void findcent(int x,int f,int tot){
siz[x]=1;mx[x]=0;
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];if(y==f||vis[y]) continue;
findcent(y,x,tot);siz[x]+=siz[y];
chkmax(mx[x],siz[y]);
} chkmax(mx[x],tot-siz[x]);
if(mx[x]<mx[cent]) cent=x;
}
vector<int> pts;
void getdis(int x,int f){
pts.pb(x);
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];if(y==f||vis[y]) continue;
dep[y]=dep[x]+1;sum[y]=sum[x]+a[y];
_sum[y]=_sum[x]+sum[y];
__sum[y]=__sum[x]+1ll*a[y]*dep[y];
getdis(y,x);
}
}
void divcent(int x){
vis[x]=1;chkmax(ans,a[x]);dep[x]=1;
sum[x]=_sum[x]=__sum[x]=a[x];
modify(1,line(sum[x],__sum[x]));
stack<int> stk;
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];if(vis[y]) continue;
sum[y]=sum[x]+a[y];dep[y]=2;
_sum[y]=sum[x]+a[x]+a[y];
__sum[y]=sum[x]+(a[y]<<1);
pts.clear();getdis(y,x);subsiz[y]=pts.size();
for(int i=0;i<pts.size();i++){
int z=pts[i];
chkmax(ans,_sum[z]+query(1,dep[z]-1)-1ll*a[x]*dep[z]);
}
for(int i=0;i<pts.size();i++){
int z=pts[i];
modify(1,line(sum[z],__sum[z]));
} stk.push(y);
}
for(int i=0;i<opts.size();i++) s[opts[i]].v=line();
opts.clear();
while(!stk.empty()){
int y=stk.top();stk.pop();
pts.clear();getdis(y,x);
for(int i=0;i<pts.size();i++){
int z=pts[i];
chkmax(ans,_sum[z]+query(1,dep[z]-1)-1ll*a[x]*dep[z]);
}
for(int i=0;i<pts.size();i++){
int z=pts[i];
modify(1,line(sum[z],__sum[z]));
}
} chkmax(ans,_sum[x]+query(1,0)-a[x]);
for(int i=0;i<opts.size();i++) s[opts[i]].v=line();
opts.clear();
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];if(vis[y]) continue;
cent=0;findcent(y,x,subsiz[y]);
divcent(cent);
}
}
int main(){
scanf("%d",&n);mx[0]=1e9;build(1,0,n);
for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),adde(u,v),adde(v,u);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
findcent(1,0,n);divcent(cent);printf("%lld
",ans);
return 0;
}