不妨设一个节点 (v) 的权值为 (val_v) ,向其子树最多可以得到的链的最大权值和为 (down_v) ,在子树中最长链的权值为 (down_v) 。
然后就可以 (dfs) 处理出 (down_v) (必须由 (v) 节点向下)和 (down_v) (可以由一棵子树从下往上经过 (v) 节点然后转向另一棵子树,也可以是某棵子树的内部的最长链)。
开一个队列,遍历每个节点 (u) ,求 (u) 和其儿子 (v) 这条边为分界线的两棵树的内部最长链。这时就相当于把 (v) 节点和 (u) 以及 (u) 的其余子节点都分开了。
先求出 (u) 以及 (u) 除了 (v) 以外的儿子节点所形成的树的最长链,记为 (outside) 。
设 (up_u) 为 (u) 向上可连成的最长链(不包括 (u) 本身), (predown_x) 和 (ppredown_x) 为 (u) 的儿子的 (down_x) 的前缀最大和次大, (prebest_x) 为儿子的前缀最大, (sufdown_x,ssufdown_x,sufbest_x) 分别为相应的后缀 (down_x) 最大和次大以及 (best_x) 最大。
然后就可以考虑 (outside) 可以怎样连成了。
设当前枚举到第 (i) 个儿子,则 (outside) 是下面五种情况的最大值: (up_u+val_u+max {predown_{i-1},sufdown_{i+1} },val_u+predown_{i-1}+sufdown_{i+1},val_u+predown_{i-1}+ppredown_{i-1},val_u+sufdown_{i+1}+ssufdown_{i+1},max {prebest_{i-1},sufbest_{i+1} }) 。
然后将答案对 (outside+best_v) 取 (max) 即可( (v) 是第 (i) 个儿子)。
当然,还需要更新 (up_v) ,计算方法是 (up_v=val_u+max {up_u,max {predown_{i-1},sufdown_{i+1} } }) 。
注意 (up_v) 不包括 (val_v) 。
如有疑问,评论区见。
#include <bits/stdc++.h>
using namespace std;
inline int read()
{
int ret=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0')
{
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
ret=(ret<<1)+(ret<<3)+ch-'0';
ch=getchar();
}
return ret*f;
}
int n,u,v,num,head[200005];
long long val[100005],bst[100005],dwn[100005],prebest[100005],sufbest[100005],predown[100005],ppredown[100005],sufdown[100005],ssufdown[100005],upp[100005],ans;
struct edge
{
int ver,next;
}e[200005];
inline void adde(int u,int v)
{
e[num].ver=v;
e[num].next=head[u];
head[u]=num++;
}
void dfs(int u,int fa)
{
long long maxn1=0,maxn2=0;
for(int i=head[u];i!=-1;i=e[i].next)
{
int v=e[i].ver;
if(v==fa)
continue;
dfs(v,u);
if(dwn[v]>maxn1)
{
maxn2=maxn1;
maxn1=dwn[v];
}
else if(dwn[v]>maxn2)
maxn2=dwn[v];
bst[u]=max(bst[u],bst[v]);
}
dwn[u]=maxn1+val[u];
bst[u]=max(bst[u],maxn1+maxn2+val[u]);
}
void solve()
{
queue<pair<int,int> > q;
q.push(make_pair(1,0));
while(!q.empty())
{
int u=q.front().first,fa=q.front().second;
q.pop();
vector<int> child;
child.push_back(0);
for(register int i=head[u];i!=-1;i=e[i].next)
{
int v=e[i].ver;
if(v!=fa)
child.push_back(v);
}
int siz=child.size();
prebest[0]=0;
predown[0]=0;
ppredown[0]=0;
for(register int i=1;i<siz;i++)
{
int v=child[i];
prebest[i]=max(prebest[i-1],bst[v]);
predown[i]=predown[i-1],ppredown[i]=ppredown[i-1];
if(dwn[v]>predown[i])
{
ppredown[i]=predown[i];
predown[i]=dwn[v];
}
else if(dwn[v]>ppredown[i])
ppredown[i]=dwn[v];
}
sufbest[siz]=0;
sufdown[siz]=0;
ssufdown[siz]=0;
for(register int i=siz-1;i;i--)
{
int v=child[i];
sufbest[i]=max(sufbest[i+1], bst[v]);
sufdown[i]=sufdown[i+1], ssufdown[i] = ssufdown[i+1];
if(dwn[v]>sufdown[i])
{
ssufdown[i]=sufdown[i];
sufdown[i]=dwn[v];
}
else if(dwn[v]>ssufdown[i])
ssufdown[i]=dwn[v];
}
for(register int i=1;i<siz;i++)
{
int v=child[i];
long long otsd=upp[u]+val[u]+max(predown[i-1],sufdown[i+1]);
otsd=max(otsd,val[u]+predown[i-1]+ppredown[i-1]);
otsd=max(otsd,val[u]+sufdown[i+1]+ssufdown[i+1]);
otsd=max(otsd,val[u]+predown[i-1]+sufdown[i+1]);
otsd=max(otsd,max(prebest[i-1],sufbest[i+1]));
ans=max(ans,otsd+bst[v]);
}
for(register int i=1;i<siz;i++)
{
int v=child[i];
upp[v]=val[u]+max(upp[u],max(predown[i-1],sufdown[i+1]));
q.push(make_pair(v,u));
}
}
}
int main()
{
n=read();
for(register int i=1;i<=n;i++)
val[i]=read();
memset(head,-1,sizeof(head));
for(register int i=1;i<n;i++)
{
u=read();
v=read();
adde(u,v);
adde(v,u);
}
dfs(1,0);
solve();
printf("%lld
",ans);
return 0;
}