题解
考虑把式子化一下,因为只有一个式子跟第二棵树有关,所以我们可以考虑把前面的式子化成跟 $ ext{lca}$ 没有关系,即 $frac{1}{2}(dp_u+dp_v+dis(u,v))$ 。因此我们可以利用边分治,每次把两边的点黑白染色,构成虚树,然后做 $ ext{dp}$ 即可。这里要注意 $ ext{lca}$ 要 $O(1)$ 求,虚树构成过程中不能排序,故我们可以一开始就按照第二棵树的dfs排序好,之后分治下去即可。效率: $O(nlogn)$ 。
代码
#include <bits/stdc++.h> #define LL long long using namespace std; const int N=4e5+5,N2=N<<1,N4=N<<2; int n,m,t=1,fa[22][N2],Lg[N2],d[N],e[N],o,rt,su,hd[N2],sz[N2]; int V[N4],W[N4],nx[N4],b[N],id[N],col[N],tp,S[N],h[2][N],c; LL dp[N],Dp[N],sm[N],f[2][N],ans=-2e18; bool vis[N2];vector<int>X[N],Y[N]; void Add(int u,int v,int w){ X[u].push_back(v);Y[u].push_back(w); } void add(int u,int v,int w){ nx[++t]=hd[u];V[hd[u]=t]=v;W[t]=w; } void add(int u,int v){X[u].push_back(v);} void rebuild(int u,int fr){ int x=0,z=X[u].size(); for (int v,w,i=0;i<z;i++){ v=X[u][i];w=Y[u][i]; if (v==fr) continue;dp[v]=dp[u]+w; if (!x) add(u,v,w),add(v,u,w),x=u; else m++,add(x,m,0),add(m,x,0), add(m,v,w),add(v,m,w),x=m; rebuild(v,u); } } void dfs(int u,int fr){ int z=X[u].size(); fa[0][e[u]=++c]=u;b[id[u]=++t]=u; for (int v,w,i=0;i<z;i++){ v=X[u][i];w=Y[u][i]; if (v==fr) continue; Dp[v]=Dp[u]+w;d[v]=d[u]+1; dfs(v,u);fa[0][++c]=u; } } int Min(int u,int v){return d[u]<d[v]?u:v;} int qry(int l,int r){ l=e[l];r=e[r]; if (l>r) swap(l,r);int i=Lg[r-l+1]; return Min(fa[i][l],fa[i][r-(1<<i)+1]); } void Sz(int u,int fr){ sz[u]=1; for (int v,i=hd[u];i;i=nx[i]) if (!vis[i>>1] && (v=V[i])!=fr) Sz(v,u),sz[u]+=sz[v]; } void Rt(int u,int fr){ for (int v,w,i=hd[u];i;i=nx[i]) if (!vis[i>>1] && (v=V[i])!=fr){ w=max(sz[v],o-sz[v]); if (w<su) rt=i,su=w;Rt(v,u); } } void find(int u,int fr,LL w,int cl){ if (u<=n) col[u]=cl,sm[u]=w; for (int v,i=hd[u];i;i=nx[i]) if (!vis[i>>1] && (v=V[i])!=fr) find(V[i],u,w+W[i],cl); } void ins(int u){ if (tp<1){S[++tp]=u;return;} int x=qry(S[tp],u); if (x==S[tp]){S[++tp]=u;return;} while(tp>1 && id[S[tp-1]]>=id[x]) add(S[tp-1],S[tp]),tp--; if (x!=S[tp]) add(x,S[tp]),S[tp]=x; S[++tp]=u; } void get(int u){ int z=X[u].size(); f[0][u]=f[1][u]=-2e18; if (~col[u]) f[col[u]][u]=dp[u]+sm[u]; for (int v,i=0;i<z;i++){ v=X[u][i];get(v); for (int j=0;j<2;j++) ans=max(ans,f[j][u]+f[!j][v]-2ll*Dp[u]); for (int j=0;j<2;j++) f[j][u]=max(f[j][u],f[j][v]); } X[u].clear(); } void solve(int u,int l,int r){ Sz(u,0);o=sz[u];rt=0;su=1e9; Rt(u,0);if (!rt) return; int x=V[rt],y=V[rt^1];vis[rt>>1]=1; find(x,y,0,0);find(y,x,W[rt],1); if (b[l]!=1) ins(1); for (int i=l;i<=r;i++) ins(b[i]); while(tp>1) add(S[tp-1],S[tp]),tp--; tp=0;get(1);int v[2]={0}; for (int w,i=l;i<=r;i++) w=col[b[i]],h[w][++v[w]]=b[i],col[b[i]]=-1; for (int i=0;i<v[0];i++) b[i+l]=h[0][i+1]; for (int i=0;i<v[1];i++) b[r-i]=h[1][v[1]-i]; solve(x,l,l+v[0]-1);solve(y,r-v[1]+1,r); } int main(){ cin>>n;m=n; for (int i=1,x,y,z;i<n;i++) scanf("%d%d%d",&x,&y,&z), Add(x,y,z),Add(y,x,z);rebuild(1,0); for (int i=1;i<=n;i++) X[i].clear(),Y[i].clear(),col[i]=-1;t=0; for (int i=1,x,y,z;i<n;i++) scanf("%d%d%d",&x,&y,&z), Add(x,y,z),Add(y,x,z);dfs(1,0); for (int i=2;i<=c;i++) Lg[i]=Lg[i>>1]+1; for (int i=c;i;i--) for (int j=1;i+(1<<j)<=c+1;j++) fa[j][i]=Min(fa[j-1][i],fa[j-1][i+(1<<(j-1))]); for (int i=1;i<=n;i++) X[i].clear(); solve(1,1,n);ans>>=1; for (int i=1;i<=n;i++) ans=max(ans,dp[i]-Dp[i]); cout<<ans<<endl;return 0; }