原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ347.html
题意
有三棵树,边有边权。
对于所有点对 (x,y) 求在三棵树上 x 到 y 的距离之和 的最大值。
点数 <=100000
题解
我自闭了。
在此之前,我没写过边分治,只写过一次虚树。
我自闭了。
一棵树怎么做?
树的直径。
两棵树怎么做?
有一个定理:从点集A中的点到点集B中的点的最长路径的两端点一定属于 点集A中最长路两端点和点集B中最长路两端点 构成的集合。
首先,在第一棵树上,我们求出每一个点的深度(即到根距离),记点 x 的深度为 D[x] 。则点 x 到 点 y 在这两棵树上的距离之和为 D[x]+D[y]-2D[LCA(x,y)] + dis(x,y) 。(其中dis(x,y) 代表在第二棵树上的距离)
我们考虑在第二棵树上,对于任意一个点 x ,新建节点 x' , x' 仅和 x 有一条权值为 D[x] 的边。那么 x 到 y 的距离就是 dis(x',y') - 2D[LCA(x,y)] 。我们考虑对第一棵树进行dfs,对于每一个节点,依次将子树点集中的最远点对合并到父亲上来(这里要用到之前说的定理),顺便更新答案即可。
那么三棵树呢?
给多出来的那棵树边分治一下,将第三棵树中的边 (x,x') 的权值修改成 D[x] + 在第三棵树中 x 到边分中心的距离。注意边分的时候作为边分中心的那条边的权值不要忘记加上。
对于第三棵树中分治出来的点集,我们需要在第二棵树上建虚树。
实际写代码的时候,节点 x' 以及边 (x,x') 都是不用加上的。
代码
#include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof (x)) using namespace std; typedef long long LL; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=100005*2; const LL INF=5e17; int n,m; struct Gragh{ static const int M=N*2; int cnt,y[M],nxt[M],fst[N]; LL z[M]; void clear(){ cnt=1; memset(fst,0,sizeof fst); } void add(int a,int b,LL c){ y[++cnt]=b,z[cnt]=c,nxt[cnt]=fst[a],fst[a]=cnt; } }g[4],G; #define For(i,y,g,x) for (LL i=g.fst[x],y=g.y[i];i;i=g.nxt[i],y=g.y[i]) #define Foryx(y,g,x) For(_index,y,g,x) #define Fory(y,g) Foryx(y,g,x) #define Forg(g) Fory(y,g) int fa[4][N][17],depth[4][N]; LL len[4][N],addv[N]; int I[N],O[N],_Time=0; void dfs1(int _id,int x,int pre,int d,LL L){ fa[_id][x][0]=pre; for (int i=1;i<17;i++) fa[_id][x][i]=fa[_id][fa[_id][x][i-1]][i-1]; depth[_id][x]=d,len[_id][x]=L; Forg(g[_id]) if (y!=pre) dfs1(_id,y,x,d+1,L+g[_id].z[_index]); } void dfs2(int x,int pre){ I[x]=++_Time; Forg(g[2]) if (y!=pre) dfs2(y,x); O[x]=_Time; } int LCA(int id,int x,int y){ if (depth[id][x]<depth[id][y]) swap(x,y); for (int i=16;i>=0;i--) if (depth[id][x]-(1<<i)>=depth[id][y]) x=fa[id][x][i]; if (x==y) return x; for (int i=16;i>=0;i--) if (fa[id][x][i]!=fa[id][y][i]) x=fa[id][x][i],y=fa[id][y][i]; return fa[id][x][0]; } LL Dis(int id,int x,int y){ return len[id][x]+len[id][y]-len[id][LCA(id,x,y)]*2; } void dfs3(Gragh &g1,Gragh &g2,int x,int pre){ int p=x; Forg(g1) if (y!=pre){ m++; g2.add(p,m,0),g2.add(m,p,0); g2.add(y,m,g1.z[_index]); g2.add(m,y,g1.z[_index]); p=m; } Forg(g1) if (y!=pre) dfs3(g1,g2,y,x); } void rebuild(Gragh &g,Gragh &res){ res.clear(),m=n; dfs3(g,res,1,0); } int vis[N],size[N],Size,ckv[N]; int RT,RTF,Time=0; LL LEN; vector <int> node; void dfs4(int x,int pre){ if (x<=n) node.push_back(x); size[x]=1; Forg(g[0]) if (y!=pre&&!vis[y]) dfs4(y,x),size[x]+=size[y]; ckv[x]=max(size[x],Size-size[x]); if (!RT||ckv[x]<ckv[RT]) RT=x,RTF=pre; } int tag[N]; LL ans; void dfs5(int x,int pre,int Tag,LL D){ tag[x]=Tag,addv[x]=D; Forg(g[0]) if (y!=pre&&!vis[y]) dfs5(y,x,Tag,D+g[0].z[_index]); } LL SpDis(int a,int b){ if (!a||!b) return -INF; return Dis(3,a,b)+addv[a]+addv[b]; } bool cmpI(int x,int y){ return I[x]<I[y]; } bool isfather(int x,int y){//x is father of y return I[x]<=I[y]&&I[y]<=O[x]; } struct road{ int x,y; LL len; road(){} road(int _x,int _y,LL _len){ x=_x,y=_y,len=_len; } }r0[N],r1[N]; void Updr(road &a,road b){ LL d1=a.len,d2=SpDis(a.x,b.x),d3=SpDis(a.x,b.y); LL d4=b.len,d5=SpDis(a.y,b.x),d6=SpDis(a.y,b.y); LL mx=max(d1,max(d2,max(d3,max(d4,max(d5,d6))))); if (d1==mx) a=road(a.x,a.y,d1); else if (d2==mx) a=road(a.x,b.x,d2); else if (d3==mx) a=road(a.x,b.y,d3); else if (d4==mx) a=road(b.x,b.y,d4); else if (d5==mx) a=road(a.y,b.x,d5); else a=road(a.y,b.y,d6); } void Upda(road a,road b,LL Add){ ans=max(ans,SpDis(a.x,b.x)+Add); ans=max(ans,SpDis(a.x,b.y)+Add); ans=max(ans,SpDis(a.y,b.x)+Add); ans=max(ans,SpDis(a.y,b.y)+Add); } void solve3(int x,int pre,LL D,LL Addlen){ addv[x]+=D,r0[x]=r1[x]=road(0,0,-INF); if (tag[x]==-Time) Updr(r0[x],road(x,x,0)); if (tag[x]==Time) Updr(r1[x],road(x,x,0)); Forg(G){ solve3(y,x,D+G.z[_index],Addlen); Upda(r0[x],r1[y],Addlen-D*2); Upda(r1[x],r0[y],Addlen-D*2); Updr(r0[x],r0[y]); Updr(r1[x],r1[y]); } } void make_tree(LL Addlen){ static int st[N],top; top=st[0]=0; if (abs(tag[1])!=Time) node.push_back(1); sort(node.begin(),node.end(),cmpI); G.cnt=1; for (auto x : node){ if (top>0&&!isfather(st[top],x)){ while (top>1&&!isfather(st[top-1],x)) G.add(st[top-1],st[top],len[2][st[top]]-len[2][st[top-1]]),top--; int y=LCA(2,st[top],x); if (y!=st[top-1]) G.fst[y]=0; G.add(y,st[top],len[2][st[top]]-len[2][y]); top--; if (y!=st[top]) st[++top]=y; } G.fst[x]=0,st[++top]=x; } for (int i=1;i<top;i++) G.add(st[i],st[i+1],len[2][st[i+1]]-len[2][st[i]]); solve3(1,0,0,Addlen); } void solve(int _x){ if (Size==1) return; Time++,RT=RTF=0; node.clear(); dfs4(_x,0); int x=RT,y=RTF; dfs5(x,y,Time,0); dfs5(y,x,-Time,0); for (int i=g[0].fst[x];i;i=g[0].nxt[i]) if (g[0].y[i]==y){ LEN=g[0].z[i]; break; } make_tree(LEN); int sz1=size[x],sz2=Size-sz1; vis[y]=1,Size=sz1,solve(x),vis[y]=0; vis[x]=1,Size=sz2,solve(y),vis[x]=0; } int main(){ n=read(); for (int t=1;t<=3;t++){ g[t].clear(); for (int i=1;i<n;i++){ int x=read(),y=read(); LL z=read(); g[t].add(x,y,z); g[t].add(y,x,z); } } dfs1(2,1,0,0,0); dfs1(3,1,0,0,0); dfs2(1,0); clr(addv),clr(vis); rebuild(g[1],g[0]); Size=m,ans=0; solve(1); cout<<ans<<endl; return 0; }