题意是给了 n 个点的树,会有m条链条 链接两个点,计算出他们没有公共点的最大价值, 公共点时这样计算的只要在他们 lca 这条链上有公共点的就说明他们相交
dp[i]为这个点包含的子树所能得到的最大价值
sum[i]表示这个点没有选择经过i这个点链条的总价值
两种选择
这个点没有被选择
dp[i]=sum[i]=sigma(dp[k])k为i的子树
选择了某个链
假设这条链 为(tyuijk)
那么dp[i]=(sum[i]-dp[u]-dp[j])+(sum[j]-dp[k])+dp[k] +(sum[u]-dp[y])+(sum[y]-dp[t])+sum[t];
整理后发现 dp[i]=sum[i] +(sum[j]-dp[j])+(sum[k]-dp[k])+(sum[u]-dp[u])+(sum[y]-dp[y])+(sum[t]-dp[t]);
使用lca计算出每条链的最近公共祖先,在这个最近公共祖先上判断是否使用这条链,还有我们可以使用时间戳加树状数组来求得sum和dp
#include <iostream> #include <algorithm> #include <string.h> #include <cstdio> #include <vector> using namespace std; const int maxn=100000+10; int to[maxn*2],nx[maxn*2],H[maxn*2],numofedg,timoflook; int fa[maxn][20],first[maxn],last[maxn],depth[maxn]; void addedg(int u, int v) { numofedg++; to[numofedg]=v; nx[numofedg]=H[u]; H[u]=numofedg; numofedg++; to[numofedg]=u; nx[numofedg]=H[v]; H[v]=numofedg; } void dfs(int cur, int per, int dep) { first[cur]=++timoflook; depth[cur]=dep; fa[cur][0]=per; for(int i=1; i<20; i++) { fa[cur][i]=fa[ fa[cur][i-1] ][ i-1 ]; } for(int i=H[cur]; i; i=nx[i]) { if(to[i]==per)continue; dfs(to[i],cur,dep+1); } last[cur]=++timoflook; } int getlca(int u,int v) { if(depth[u]<depth[v])swap(u,v); for(int i=19; i>=0; i--) { if(depth[fa[u][i]]>=depth[v]) u=fa[u][i]; if(u==v)return u; } for(int i=19; i>=0; i--) { if(fa[u][i]!=fa[v][i]) { u=fa[u][i]; v=fa[v][i]; } } return fa[u][0]; } struct Edg { int u,v,lca,val; }P[maxn]; vector<int>E[maxn]; int dp[maxn],sum[maxn],CS[maxn*3],CD[maxn*3]; int lowbit(int x) { return x&-x; } void add(int x, int d, int *C) { while(x<=timoflook) { C[x]+=d; x+=lowbit(x); } } int getsum(int x, int *C) { int ret=0; while(x>0) { ret+=C[x]; x-=lowbit(x); } return ret; } void solve(int cur, int per) { dp[cur]=sum[cur]=0; for(int i=H[cur]; i; i=nx[i]) { if(to[i]==per)continue; solve(to[i],cur); sum[cur]+=dp[to[i]]; } dp[cur]=sum[cur]; for(int i=0; i<E[cur].size(); i++) { int id=E[cur][i]; int u=P[id].u; int v=P[id].v; int t1=getsum(first[u],CS); int t2=getsum(first[v],CS); int t3=getsum(first[u],CD); int t4=getsum(first[v],CD); int tmp=t1+t2-t3-t4; dp[cur]=max(dp[cur],tmp+P[id].val+sum[cur]); } add(first[cur],sum[cur],CS); add(last[cur],-sum[cur],CS); add(first[cur],dp[cur],CD); add(last[cur],-dp[cur],CD); } int main() { int cas; scanf("%d",&cas); for(int cc=1; cc<=cas; cc++) { int n,m; timoflook=numofedg=0; scanf("%d%d",&n,&m); for(int i=0; i<=n; i++) { CS[i*2]=CS[i*2+1]=CD[i*2]=CD[i*2+1]=0; H[i]=0;E[i].clear(); } for(int i=1; i<n; i++) { int u,v; scanf("%d%d",&u,&v); addedg(u,v); } fa[1][0]=1; dfs(1,1,0); for(int i=0; i<m; i++) { scanf("%d%d%d",&P[i].u,&P[i].v,&P[i].val); P[i].lca=getlca(P[i].u,P[i].v); E[P[i].lca].push_back(i); } solve(1,-1); printf("%d ",dp[1]); } return 0; }