蜜汁树形dp...
首先分析一下:他要求一条边至多只能经过两次,那么很容易会发现:从x到y这一条路径上的所有边都只会被经过一次。(如果过去再回来那么还要过去,这样就三次了,显然不合法)
那么其他能产生贡献的部分就只有一下几个部分:x,y的子树内部,LCA(x,y)的上半部分的树以及x-y路径上的点向外延伸所形成的部分
这三部分互相独立又互相关联,所以我们设计三个dp对他们进行转移
记dp1[x]代表x的子树内所形成的的贡献,dp3[x]表示x以上的树所形成的贡献(包括x的兄弟节点)
这样就设计出了第一个和第三个状态
至于第二个,我们可以发现这个情况等价于路径上所有点向他的所有兄弟节点去跑,这样延伸出来的一种情况。
那么我们设计dp2[x]代表x的兄弟节点对x的贡献
接下来我们考虑转移:
首先,dp1非常好转移,只需向下dfs,每次回溯时只要能产生正的贡献就向上更新,同时记录每个点是否可以向上更新即可
当dp1出来了之后,dp2也就很好转移了,因为如果父节点的dp1没有利用这个节点进行更新,那么这个节点的dp2就是他父节点的dp1
如果dp1利用了这个节点进行更新,那就将dp1减掉这个节点提供的贡献赋给dp2即可
而dp3,很显然dp3要分为两部分,一部分是父节点向上,一部分是兄弟节点,兄弟节点部分就是dp2,而父节点向上那就是父节点的dp3,这也就完成了更新
这样三个dp就维护出来了
如果对概念不是特别清楚,画几个图来理解一下:
那么更新完这三个,查询也就变得简单了:首先统计x-y路径上的部分,然后统计x子树内,y子树内,LCA(x,y)以上的部分,以及x-y路径上的点向外延伸的部分,而这部分可以在树链上用前缀和维护。
但是这里有个小问题:由于x和y在跳到LCA上时会跳到LCA的两个子节点上,那么对这两个子节点,我们不能加两次兄弟节点的贡献(这样就加重了),所以我们去掉一部分即可。
贴代码:
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define ll long long
using namespace std;
struct Edge
{
int next;
int to;
ll val;
}edge[600005];
int head[300005];
int f[300005][30];
ll dp1[300005];
ll dp2[300005];
ll dp3[300005];
ll fv[300005];
bool used1[300005];
ll dis[300005];//边权距离
ll d[300005];//点权距离
ll v[300005];
ll s[300005];
int dep[300005];
int cnt=1;
int n,q;
void init()
{
memset(head,-1,sizeof(head));
cnt=1;
}
void add(int l,int r,ll w)
{
edge[cnt].next=head[l];
edge[cnt].to=r;
edge[cnt].val=w;
head[l]=cnt++;
}
void dfs(int x,int fx)//处理dp1
{
f[x][0]=fx;
dep[x]=dep[fx]+1;
for(int i=head[x];i!=-1;i=edge[i].next)
{
int to=edge[i].to;
if(to==fx)
{
continue;
}
fv[to]=edge[i].val;
dis[to]=dis[x]+edge[i].val;
d[to]=d[x]+v[to];
dfs(to,x);
if(dp1[to]+v[to]-2*edge[i].val>=0)
{
dp1[x]+=dp1[to]+v[to]-2*edge[i].val;
used1[to]=1;
}
}
for(int i=head[x];i!=-1;i=edge[i].next)
{
int to=edge[i].to;
if(to==fx)
{
continue;
}
if(!used1[to])
{
dp2[to]=dp1[x];
}else
{
dp2[to]=dp1[x]-(dp1[to]+v[to]-2*edge[i].val);
}
}
}
void redfs(int x,int fx)
{
s[x]+=dp2[x];
for(int i=head[x];i!=-1;i=edge[i].next)
{
int to=edge[i].to;
if(to==fx)
{
continue;
}
dp3[to]=max((ll)0,dp3[x]+v[x]-2*edge[i].val+dp2[to]);
s[to]+=s[x];
redfs(to,x);
}
}
void getf()
{
for(int i=1;i<=25;i++)
{
for(int j=1;j<=n;j++)
{
f[j][i]=f[f[j][i-1]][i-1];
}
}
}
int LCA(int x,int y)
{
if(dep[x]>dep[y])
{
swap(x,y);
}
for(int i=25;i>=0;i--)
{
if(dep[f[y][i]]>=dep[x])
{
y=f[y][i];
}
}
if(x==y)
{
return x;
}
int ret;
for(int i=25;i>=0;i--)
{
if(f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}else
{
ret=f[x][i];
}
}
return ret;
}
ll cal(int x,int y)
{
ll ret=0;
if(dep[x]>dep[y])
{
swap(x,y);
}
int f1=LCA(x,y);
if(f1!=1)
{
ret+=d[x]+d[y]-d[f1]-d[f[f1][0]];
ret-=dis[x]+dis[y]-2*dis[f1];
}else
{
ret+=d[x]+d[y]-d[f1];
ret-=dis[x]+dis[y]-dis[f1];
}
if(x==f1)
{
ret+=dp1[y];
ret+=dp3[x];
ret+=s[y];
ret-=s[x];
}else
{
ret+=dp1[x];
ret+=dp1[y];
ret+=dp3[f1];
int ff1=x,ff2=y;
for(int i=25;i>=0;i--)
{
if(dep[f[ff1][i]]>dep[f1])
{
ff1=f[ff1][i];
}
if(dep[f[ff2][i]]>dep[f1])
{
ff2=f[ff2][i];
}
}
ret+=s[x]-s[ff1];
ret+=s[y]-s[ff2];
ret+=dp2[ff1];
if(used1[ff2])
{
ret-=dp1[ff2]+v[ff2]-2*fv[ff2];
}
}
return ret;
}
inline int read()
{
int f=1,x=0;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int main()
{
n=read(),q=read();
init();
for(int i=1;i<=n;i++)
{
v[i]=(ll)read();
}
for(int i=1;i<n;i++)
{
int x=read(),y=read(),z=read();
add(x,y,(ll)z);
add(y,x,(ll)z);
}
d[1]=v[1];
dfs(1,1);
getf();
redfs(1,1);
for(int i=1;i<=q;i++)
{
int x=read(),y=read();
printf("%lld
",cal(x,y));
}
return 0;
}