题目:https://loj.ac/problem/2339
两棵树的话,可以用 CTSC2018 暴力写挂的方法,边分治+虚树。O(nlogn)。
考虑怎么在这个方法上再加一棵树。发现很难弄。
看了看题解,发现两棵树还有别的做法。
就是要最大化 d1[ x ] + d2[ x ] + d1[ y ] + d2[ y ] - 2*d1[ lca1(x,y) ] - 2*d2[ lca2(x,y) ] ,考虑在第一棵树 T1 上 dfs 地枚举 lca1 ,那么考虑的答案就是 T1 上在当前点 cr 的不同子树里的 x 和 y 。
考虑 cr 的之前子树 v1 和当前子树 v2 怎么合并。 v1 和 v2 都记录着自己子树里的答案的两个点 x 和 y 。
似乎根据树的直径证明的类似方法可以得知 cr 的 x 就是 v1 和 v2 的 x 中的一个, cr 的 y 就是 v1 和 v2 的 y 中的一个。
所以把两个 x 和两个 y 组合一下,看看谁的 d1[ x ] + d2[ x ] + d1[ y ] + d2[ y ] - 2*d2[ lca2(x,y) ] 最小,谁就是 cr 的 x 和 y 。
做完 cr 之后,因为要换 lca1 了,所以先贡献一下答案,就是把 cr 记录的 x 和 y 按上面要最大化的那个式子贡献给答案。
只要 RMQ 求 lca 就可以 O(n) 。
所以三棵树就是在这个两棵树的做法上给第三棵树套一个边分治。
就是在当前边分治的情况下,枚举 lca1 ,式子变成最大化 d1[ x ] + d2[ x ] + d3[ x ] + d1[ y ] + d2[ y ] + d2[ y ] - 2*d1[ lca1(x,y) ] - 2*d2[ lca2(x,y) ] + tw,其中 d1[ ] , d2[ ] 是在树 T1 和 T2 上的带权深度,d3[ ] 是在 T3 上到分治中心边的距离, tw 是分治中心边的权值。
枚举 lca1 之后,不仅 x 和 y 不能在 T1 上 lca1 的同一棵子树中,且 x 和 y 还得分别是 T3 的分治中心两边的点,所以 T2 上 DP 的时候,每个点要记两对 ( x , y ) ,表示 T3 两边的直径。
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> #define ll long long #define pil pair<int,ll> #define pb push_back #define mkp make_pair using namespace std; ll rdn() { ll ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } ll Mx(ll a,ll b){return a>b?a:b;} ll Mn(ll a,ll b){return a<b?a:b;} const int N=2e5+5,K=17; const ll INF=1e18; int n,tn,hd[N],xnt=1,to[N<<1],nxt[N<<1];ll w[N<<1]; int siz[N],mn,Rt,lx[N]; vector<pil> vt[N]; ll d1[N],d2[N],d3[N],ans; namespace T3{ int hd[N],xnt,to[N<<1],nxt[N<<1];ll w[N<<1]; int dep[N],bg[N],en[N],tim,st[N][K],lg[N],bin[K+5]; void add(int x,int y,ll z) { to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=z; to[++xnt]=x;nxt[xnt]=hd[y];hd[y]=xnt;w[xnt]=z; } void ini_dfs(int cr,int fa) { bg[cr]=++tim; st[tim][0]=cr; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dep[v]=dep[cr]+1; d3[v]=d3[cr]+w[i]; ini_dfs(v,cr); st[++tim][0]=cr;///// } en[cr]=tim; } void init() { ll z; for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),z=rdn(),add(u,v,z); ini_dfs(1,0); int tn=n<<1; for(int i=2;i<=tn;i++)lg[i]=lg[i>>1]+1; bin[0]=1;for(int i=1;i<=lg[tn];i++)bin[i]=bin[i-1]<<1; for(int t=1;t<=lg[tn];t++) for(int i=1;i+bin[t]-1<=tn;i++) { int u=st[i][t-1], v=st[i+bin[t-1]][t-1]; st[i][t]=(dep[u]<dep[v]?u:v); } } int get_lca(int x,int y) { if(bg[x]>bg[y])swap(x,y); int d=lg[en[y]-bg[x]+1]; int c1=st[bg[x]][d], c2=st[en[y]-bin[d]+1][d]; int ret=dep[c1]<dep[c2]?c1:c2; return ret; } } namespace T2{ int hd[N],xnt,to[N<<1],nxt[N<<1];ll w[N<<1]; int tim,a[N],ta[N],lca[N],tlca[N],dep[N],sta[N],top; struct Node{ int x,y;ll w; Node(int x=0,int y=0,ll w=0):x(x),y(y),w(w) {} }dp[N][2]; void add(int x,int y,ll z) { to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=z; to[++xnt]=x;nxt[xnt]=hd[y];hd[y]=xnt;w[xnt]=z; } void ini_dfs(int cr,int fa) { while(top&&dep[sta[top]]>=dep[cr])top--; a[++tim]=cr; lca[tim]=sta[top]; sta[++top]=cr; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dep[v]=dep[cr]+1; d2[v]=d2[cr]+w[i]; ini_dfs(v,cr); } } void init() { ll z; for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),z=rdn(),add(u,v,z); ini_dfs(1,0); } ll calc(int x,int y) { ll ret=d1[x]+d1[y]+d2[x]+d2[y]+d3[x]+d3[y]; int tmp=T3::get_lca(x,y); ret-=(d3[tmp]<<1ll); return ret; } Node operator+ (const Node &a,const Node &b) { int x1=a.x,y1=a.y,x2=b.x,y2=b.y; Node ret=Node(0,0,-1); ll tmp; if(x1&&x2) { tmp=calc(x1,x2); if(tmp>ret.w)ret=Node(x1,x2,tmp);} if(x1&&y2) { tmp=calc(x1,y2); if(tmp>ret.w)ret=Node(x1,y2,tmp);} if(y1&&x2) { tmp=calc(y1,x2); if(tmp>ret.w)ret=Node(y1,x2,tmp);} if(y1&&y2) { tmp=calc(y1,y2); if(tmp>ret.w)ret=Node(y1,y2,tmp);} return ret; } Node mx(Node a,Node b){ return a.w>b.w?a:b;} void link(int cr,int v,ll tw) { ll tmp=d2[cr]<<1ll; ans=Mx(ans,(dp[cr][0]+dp[v][1]).w+tw-tmp); ans=Mx(ans,(dp[cr][1]+dp[v][0]).w+tw-tmp); dp[cr][0]=mx(dp[cr][0],mx(dp[v][0],dp[cr][0]+dp[v][0])); dp[cr][1]=mx(dp[cr][1],mx(dp[v][1],dp[cr][1]+dp[v][1])); dp[v][0]=dp[v][1]=Node(0,0,-1); } int solve(int l,int r,ll tw) { sta[top=1]=a[l]; for(int i=l+1;i<=r;i++) { int lm=dep[lca[i]]; while(top&&dep[sta[top]]>lm) { if(dep[sta[top-1]]>lm)link(sta[top-1],sta[top],tw); else link(lca[i],sta[top],tw); top--; } if(sta[top]!=lca[i])sta[++top]=lca[i]; sta[++top]=a[i]; } for(int i=top-1;i;i--)link(sta[i],sta[i+1],tw); dp[sta[1]][0]=dp[sta[1]][1]=Node(0,0,-1); int mid=l-1; for(int i=l,tl=0;i<=r;i++) { if(!tl||dep[lca[i]]<dep[tl])tl=lca[i]; if(!lx[a[i]]) ta[++mid]=a[i], tlca[mid]=tl, tl=0; } int ret=mid; for(int i=l,tl=0;i<=r;i++) { if(!tl||dep[lca[i]]<dep[tl])tl=lca[i]; if(lx[a[i]]) ta[++mid]=a[i], tlca[mid]=tl, tl=0; } for(int i=l;i<=r;i++)a[i]=ta[i]; for(int i=l;i<=r;i++)lca[i]=tlca[i]; return ret; } } void add(int x,int y,ll z) { to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=z; to[++xnt]=x;nxt[xnt]=hd[y];hd[y]=xnt;w[xnt]=z; } void del_ed(int x,int y) { if(to[hd[x]]==y)hd[x]=nxt[hd[x]]; else { for(int i=hd[x],pr;i;pr=i,i=nxt[i]) if(to[i]==y){nxt[pr]=nxt[i];break;} } if(to[hd[y]]==x)hd[y]=nxt[hd[y]]; else { for(int i=hd[y],pr;i;pr=i,i=nxt[i]) if(to[i]==x){nxt[pr]=nxt[i];break;} } } void Rbuild(int cr,int fa) { for(int i=0,lst=0,lm=vt[cr].size();i<lm;i++) { int v=vt[cr][i].first;ll z=vt[cr][i].second; if(v==fa)continue; if(!lst)add(cr,v,z), lst=cr; else{ tn++; add(lst,tn,0); add(tn,v,z); lst=tn;} } for(int i=0,v,lm=vt[cr].size();i<lm;i++) if((v=vt[cr][i].first)!=fa) Rbuild(v,cr); } void get_rt(int cr,int fa,int s) { siz[cr]=1; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { get_rt(v,cr,s); siz[cr]+=siz[v]; int mx=Mx(siz[v],s-siz[v]); if(mx<mn)mn=mx,Rt=i; } } void dfs(int cr,int fa,ll lj,bool fx) { d1[cr]=lj; lx[cr]=fx; T2::dp[cr][fx]=T2::Node(cr,cr,0); T2::dp[cr][!fx]=T2::Node(0,0,-1); for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) dfs(v,cr,lj+w[i],fx); } void solve(int cr,int s,int l,int r) { int u=to[cr^1], v=to[cr]; del_ed(u,v); dfs(u,0,0,0); dfs(v,0,0,1); ll tw=w[cr]; int mid=T2::solve(l,r,tw); int ts=siz[v]; if(ts>1){mn=N;get_rt(v,0,ts);solve(Rt,ts,mid+1,r);} ts=s-ts; if(ts>1){mn=N;get_rt(u,0,ts);solve(Rt,ts,l,mid);} } int main() { n=rdn();ll z; for(int i=1,u,v;i<n;i++) { u=rdn();v=rdn();z=rdn(); vt[u].pb(mkp(v,z)); vt[v].pb(mkp(u,z)); } T2::init(); T3::init(); tn=n; Rbuild(1,0); mn=N;get_rt(1,0,tn);solve(Rt,tn,1,n); printf("%lld ",ans); return 0; }