You are given a weighted tree (undirected connected graph with no cycles, loops or multiple edges) with $n$ vertices. The edge $\{u_j, v_j\}$ has weight $w_j$. Also each vertex $i$ has its own value $a_i$ assigned to it.
Let's call a path starting in vertex $u$ and ending in vertex $v$, where each edge can appear no more than twice (regardless of direction), a 2-path. Vertices can appear in the 2-path multiple times (even start and end vertices).
For some 2-path $p$ the profit is $\mathrm{Pr}(p) = \sum_{v \in \text{distinct vertices in } p} a_v - \sum_{e \in \text{distinct edges in } p} k_e \cdot w_e$, where $k_e$ is the number of times edge $e$ appears in $p$. That is, vertices are counted once, but edges are counted the number of times they appear in $p$.
You are about to answer $q$ queries. Each query is a pair of vertices $(qu, qv)$. For each query find a 2-path $p$ from $qu$ to $qv$ with maximal profit $\mathrm{Pr}(p)$.
The first line contains two integers $n$ and $q$ ($2 \le n \le 3 \cdot 10^5$, $1 \le q \le 4 \cdot 10^5$) — the number of vertices in the tree and the number of queries.
The second line contains $n$ space-separated integers $a_1, a_2, \dots, a_n$ ($1 \le a_i \le 10^9$) — the values of the vertices.
Next $n-1$ lines contain descriptions of edges: each line contains three space-separated integers $u_i$, $v_i$ and $w_i$ ($1 \le u_i, v_i \le n$, $u_i \ne v_i$, $1 \le w_i \le 10^9$) — there is an edge $\{u_i, v_i\}$ with weight $w_i$ in the tree.
Next $q$ lines contain queries (one per line). Each query contains two integers $qu_i$ and $qv_i$ ($1 \le qu_i, qv_i \le n$) — endpoints of the 2-path you need to find.
For each query print one integer per line — the maximal profit $\mathrm{Pr}(p)$ of some 2-path $p$ with the corresponding endpoints.
Solution
对于x到y的路径上的每条边很显然只能走一次,我们要考虑的是路径上的点向外走一个Two-Path可以获得的最大权值
用树形dp进行预处理
分析一下可以发现,对一个节点的贡献可以来自祖先,后代,和兄弟(即其后代)
所以分三个dp来解决
对于当前u,子树对其贡献(dp1)不必多说
兄弟对其贡献(dp2)是其父节点的dp1减去u对父节点的贡献
祖先对其贡献(dp3)是u到祖先这条链及其这条链的分支到兄弟节点可以获得的最大权值(均可以看做上对下的贡献)
预处理出这些之后不难想到统计答案的方式
注意,在对沿路上的节点统计兄弟的贡献时,不要忘了lca的两个子节点的兄弟贡献不能重复累加
用倍增lca可以找到lca到x,y的链上的子节点
// xcj is handsome
//
// Two-Paths: for every query (u, v) on a weighted tree, print the maximal
// profit of a walk from u to v that uses each edge at most twice, where each
// distinct vertex on the walk adds its value once and each edge traversal
// subtracts the edge weight.
//
// Approach: the edges on the simple u-v path are walked once; every vertex on
// that path may additionally make round-trip excursions (each excursion edge
// walked twice) into the rest of the tree. Three tree DPs precompute the best
// excursion gains, and prefix sums along root paths let a query be answered
// with one LCA computation.
#include <algorithm>
#include <cstdio>
#include <cstring>

const int MAXN = 300005;
const int LOG = 20;

int n, q, cnt;
int fa[MAXN][LOG + 1];   // fa[u][k]: 2^k-th ancestor of u, 0 when it does not exist
int head[MAXN];          // head[u]: index of the first edge in u's adjacency list, -1 if none
int dep[MAXN];           // depth with dep[root] = 1; dep[0] = 0 acts as a sentinel
int cnn[MAXN];           // cnn[v]: index of the edge leading from v's parent to v
int order[MAXN];         // BFS order of the vertices (every parent precedes its children)
int order_len;
long long sum[MAXN];     // sum of vertex values on the root..u path
long long bro[MAXN];     // sum of dp2 on the root..u path
long long val[MAXN];     // value of each vertex
long long dis[MAXN];     // sum of edge weights on the root..u path
long long dp1[MAXN];     // best total gain of excursions from u into its own subtree
long long dp2[MAXN];     // dp1[parent(v)] without v's own contribution (gain via v's siblings)
long long dp3[MAXN];     // best gain of an excursion from v that leaves through v's parent

struct Edge {
    int fr;
    int to;
    long long val;
    int nxt;
} edge[2 * MAXN];

// Resets the adjacency lists and fixes the root depth.
void init() {
    dep[1] = 1;
    std::memset(head, -1, sizeof(head));
}

// Appends the directed edge f -> t with weight v to f's adjacency list.
void addedge(int f, int t, long long v) {
    cnt++;
    edge[cnt].fr = f;
    edge[cnt].to = t;
    edge[cnt].val = v;
    edge[cnt].nxt = head[f];
    head[f] = cnt;
}

// Gain of a round trip from v's parent into v's subtree: v's value plus the
// best excursions below v, minus the connecting edge walked twice; never
// negative because the excursion can simply be skipped.
static long long child_gain(int v) {
    return std::max(dp1[v] + val[v] - 2 * edge[cnn[v]].val, 0LL);
}

// Roots the tree at `root` and fills fa, dep, cnn, dis, sum, dp1 and dp2.
//
// Written as BFS-order passes instead of recursion: with n up to 3*10^5 a
// path-shaped tree would otherwise recurse 3*10^5 levels deep.
void dfs1(int root) {
    order_len = 0;
    order[order_len++] = root;
    // Top-down pass: parents are dequeued before their children, so every
    // ancestor table needed for fa[u][k] is already complete.
    for (int k = 0; k < order_len; k++) {
        const int u = order[k];
        sum[u] = sum[fa[u][0]] + val[u];
        for (int j = 1; j <= LOG; j++) {
            fa[u][j] = fa[fa[u][j - 1]][j - 1];
        }
        for (int e = head[u]; e != -1; e = edge[e].nxt) {
            const int v = edge[e].to;
            if (v == fa[u][0]) continue;
            cnn[v] = e;
            fa[v][0] = u;
            dep[v] = dep[u] + 1;
            dis[v] = dis[u] + edge[e].val;
            order[order_len++] = v;
        }
    }
    // Bottom-up pass: children come later in BFS order, so walking the order
    // backwards sees every child's final dp1 before its parent's.
    for (int k = order_len - 1; k >= 0; k--) {
        const int u = order[k];
        for (int e = head[u]; e != -1; e = edge[e].nxt) {
            const int v = edge[e].to;
            if (v == fa[u][0]) continue;
            dp1[u] += child_gain(v);
        }
        for (int e = head[u]; e != -1; e = edge[e].nxt) {
            const int v = edge[e].to;
            if (v == fa[u][0]) continue;
            dp2[v] = dp1[u] - child_gain(v);
        }
    }
}

// Fills bro and dp3 top-down. Must run after dfs1, whose BFS order it reuses;
// the argument is kept only for symmetry with dfs1.
void dfs2(int /*root*/) {
    for (int k = 0; k < order_len; k++) {
        const int u = order[k];
        bro[u] = bro[fa[u][0]] + dp2[u];
        for (int e = head[u]; e != -1; e = edge[e].nxt) {
            const int v = edge[e].to;
            if (v == fa[u][0]) continue;
            // Leaving v upward: pay the edge twice, collect u's value, whatever
            // u gains further up (dp3[u]) and from v's siblings (dp2[v]).
            dp3[v] = std::max(dp3[u] + val[u] - 2 * edge[e].val + dp2[v], 0LL);
        }
    }
}

// Lowest common ancestor of x and y by binary lifting.
int lca(int x, int y) {
    if (dep[y] > dep[x]) {
        std::swap(x, y);
    }
    // Lift the deeper vertex to y's depth; dep[0] == 0 < dep[y] keeps x off the sentinel.
    for (int i = LOG; i >= 0; i--) {
        if (dep[fa[x][i]] >= dep[y]) x = fa[x][i];
    }
    if (x == y) return x;
    // Lift both while their ancestors differ; they end as two children of the LCA.
    for (int i = LOG; i >= 0; i--) {
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    return fa[x][0];
}

int main() {
    init();
    if (std::scanf("%d%d", &n, &q) != 2) return 0;
    for (int i = 1; i <= n; i++) {
        std::scanf("%lld", &val[i]);
    }
    for (int i = 1; i < n; i++) {
        int u, v, w;
        std::scanf("%d%d%d", &u, &v, &w);
        addedge(u, v, (long long)w);
        addedge(v, u, (long long)w);
    }
    dfs1(1);
    dfs2(1);
    for (int query = 1; query <= q; query++) {
        int u, v;
        std::scanf("%d%d", &u, &v);
        const int f = lca(u, v);
        long long ans = 0;
        // Edges of the simple u-v path are walked exactly once.
        ans -= dis[u] + dis[v] - 2 * dis[f];
        // Values of all vertices on the path (f counted once).
        ans += sum[u] + sum[v] - sum[f] - sum[fa[f][0]];
        if (dep[u] > dep[v]) std::swap(u, v);
        if (u == f) {
            // u is an ancestor of v: excursions above u, below v, and into the
            // siblings of every path vertex strictly below u.
            ans += dp3[u] + dp1[v];
            ans += bro[v] - bro[u];
        } else {
            ans += dp3[f] + dp1[u] + dp1[v];
            ans += bro[u] + bro[v];
            // Lift u and v to the two children of f that lie on the path.
            for (int i = LOG; i >= 0; i--) {
                if (dep[fa[u][i]] > dep[f]) u = fa[u][i];
                if (dep[fa[v][i]] > dep[f]) v = fa[v][i];
            }
            ans -= bro[u] + bro[v];
            // Excursions from f itself: dp1[f] minus the gains of both path
            // children, so that the siblings at f are not counted twice.
            ans += dp2[u];
            ans -= child_gain(v);
        }
        // One integer per line, as the output section requires.
        std::printf("%lld\n", ans);
    }
    return 0;
}