题目链接:
题目大意:给出一棵n个节点的树,边有非负边权,并给出m条链,对于每条链有一个代价,要求选出两条有公共边的链使两条链的并的边权和-两条链的代价和最大。
花了一天的时间,终于搞定了这道题,不可否认这真的是一道神题,对思维和代码能力的考察都非常到位。
通过手画或者数据范围的特殊性质都不难发现两条有公共边的链的$LCA$要么相同要么不同(这不废话吗?)
而两条链的并的计算方法也可以按照上面两种情况来分类计算,我们先分别讨论两种情况的计算方法再考虑如何一起考虑两种情况。
一、两条链LCA不同
对于两条链$LCA$不同的情况可以归为以下形式:
我们不妨设黄色链的长度为$len1$,代价为$val1$,$LCA$为$lca1$;蓝色链的长度为$len2$,代价为$val2$,$LCA$为$lca2$;红点深度为$dep1$。
那么答案可以记为:$ans=len1+len2-val1-val2-dep1+max(dep(lca1),dep(lca2))$
可以发现两条链的交必定是直上直下的并且两条链位于红点子树内的端点必定处于红点不同的子树内。
那么对于每个红点i,我们维护:
$f[i][j]$表示链一端点在$i$子树中,且链的$LCA$深度为$j$的链的长度-代价的最大值
$g[i][j]$表示链一端点在$i$子树中,且链的$LCA$深度为$j$的链的长度-代价+$LCA$的深度的最大值
这两个数组可以用线段树来维护,对于每一个点动态开点建一棵线段树,剩下的只需要从下往上在树上进行线段树合并即可。
在合并时,例如$x$为$y$的父节点,我们将$y$的线段树合并到$x$的线段树上。
因为$x$的线段树上当前保存的是$y$之前的$x$的子树中的信息,所以合并保证两条链的端点位于$x$不同的子树当中。
在合并时每当遍历到一个点我们用$x$线段树中这个点的左子树中$f$的最大值与$y$线段树中对应点右子树中$g$的最大值或$x$线段树中这个点右子树中$g$的最大值与$y$线段树中对应点左子树中$f$的最大值更新答案。
因为右子树代表的深度要比左子树代表的深度大,所以对于左子树只会用$f$来更新答案,右子树只会用$g$来更新答案。
那么如何保证合并的两条链有公共边呢?
我们对于树上的一个点$x$,假设它的深度为$dep$,那么在它合并完所有子树的线段树之后,将它的线段树中$[dep-1,n]$这一部分删除掉,因为回溯到的点的深度是$dep-1$,所以要保证线段树中留下来可合并的信息深度小于$dep-1$。
时间复杂度为$O(mlogn)$
二、两条链LCA相同
如图所示为两条链LCA相同的情况:
可以发现这种情况下的答案就是(两条链长之和+蓝点间距离+绿点间距离)/2-两条链代价之和。
我们同样设红色链长为$len1$,代价为$val1$,红链的蓝点深度为$dep1$;蓝色链长为$len2$,代价为$val2$,蓝链的蓝点深度为$dep2$,红点深度为$dep$,两绿点间距离为$dis$
那么$ans=frac{len1-2*val1+dep1+len2-2*val2+dep2+dis-dep}{2}$
我们将红链的绿点权值设为$len1-2*val1+dep1$,蓝链的绿点权值设为$len2-2*val2+dep2$
再将每个绿点映射给对应的蓝点,定义两个绿点间的距离为两点间路径边权和+两点点权
这样我们只需要对蓝点做树形DP即可,对于每个蓝点维护这个点子树中所有点对应的绿点形成的树的直径的两端点。
将每个子节点的信息合并到父节点上并更新答案,注意更新答案的直径两端点要在合并的两个点信息中分别选一个点。
这样才能保证需要减掉的$dep$为当前合并的这个父节点的深度。
而这里的直径即为上述定义的两点间距离的最大值,这种直径同样符合可合并的那个结论,即:
将两棵树合并为一棵树时,新树的直径两端点为原来两棵树的直径的那四个端点中的两个。(证明参见树的直径及其性质与证明)
因为这种情况要求两条链的$LCA$相同,所以我们记录一下以每个点为$LCA$的链有哪些,将这些链的端点建虚树做上述的树形DP即可。
因为要保证两条链有公共边,所以如果父节点是当前枚举的这个$LCA$时就不合并或更新答案。
时间复杂度为$O(mlogn)$
三、排除两种情况之间的影响
这个问题想了好久,后来才发现其实保证这一点很简单。
对于两条链$LCA$相同的情况,因为我们是枚举每个点为$LCA$的情况来做的,所以保证了当前所有链的$LCA$都相同。
对于两条链$LCA$不同的情况,因为我们线段树合并时只用一个点的左子节点和右子节点来更新答案,所以保证了用于更新答案的两个链的$LCA$深度不同。
值得注意的是,因为我们无法知道两种情况中每个端点是属于上述的蓝点还是红点亦或绿点,所以对于一条链的两个端点都要存储信息,并且这样并不会影响答案的正确性。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> #define ll long long #define INF (1ll<<60) #define pr pair<int,ll> #define par pair<int,pr> using namespace std; char *p1,*p2,buf[100000]; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) int rd() {int x=0; char c=nc(); while(c<48) c=nc(); while(c>47) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x;} ll rd2() {ll x=0; char c=nc(); while(c<48) c=nc(); while(c>47) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x;} ll z; int T; ll ans; int n,m; int x,y; int tot; int dfn; int cnt; int num; int top; ll d[100010]; int s[100010]; int t[100010]; int st[100010]; int lg[200010]; int to[200010]; ll val[200010]; int vis[100010]; int dep[100010]; int ls[4000010]; int rs[4000010]; ll mx1[4000010]; ll mx2[4000010]; int next[200010]; int head[100010]; int root[100010]; int f[200010][19]; vector<int>q[100010]; vector<par>v[100010]; struct miku { int u[2]; ll v[2]; ll len; miku() { u[0]=u[1]=0; v[0]=v[1]=-INF; len=-INF; } miku(int rt,ll val) { u[0]=rt; v[0]=val; u[1]=0; v[1]=-INF; len=-INF; } }tr[100010]; bool cmp(par x,par y) { return s[x.first]<s[y.first]; } void add(int x,int y,ll z) { next[++tot]=head[x]; head[x]=tot; to[tot]=y; val[tot]=z; } void dfs(int x) { f[++dfn][0]=x; s[x]=dfn; for(int i=head[x];i;i=next[i]) { dep[to[i]]=dep[x]+1; d[to[i]]=d[x]+val[i]; dfs(to[i]); f[++dfn][0]=x; } } int mn(int x,int y) { return dep[x]<dep[y]?x:y; } void ST() { for(int i=2;i<=dfn;i++) { lg[i]=lg[i>>1]+1; } for(int j=1;j<=18;j++) { for(int i=1;i+(1<<j)-1<=dfn;i++) { f[i][j]=mn(f[i][j-1],f[i+(1<<(j-1))][j-1]); } } } int lca(int x,int y) { x=s[x]; y=s[y]; if(x>y) { swap(x,y); } int len=lg[y-x+1]; return mn(f[x][len],f[y-(1<<len)+1][len]); } ll dis(int x,int y) { return d[x]+d[y]-(d[lca(x,y)]<<1); } int build() { int rt=++cnt; ls[rt]=rs[rt]=0; mx1[rt]=mx2[rt]=-INF; return rt; } void change(int &rt,int l,int r,int k,ll v1,ll v2,ll deep) { if(!rt) { rt=build(); } mx1[rt]=max(mx1[rt],v1); mx2[rt]=max(mx2[rt],v2); if(l==r) { return ; } int mid=(l+r)>>1; if(k<=mid) { ans=max(ans,v1+mx2[rs[rt]]-deep); change(ls[rt],l,mid,k,v1,v2,deep); } else { ans=max(ans,v2+mx1[ls[rt]]-deep); change(rs[rt],mid+1,r,k,v1,v2,deep); } } void pushup(int rt) { mx1[rt]=max(mx1[ls[rt]],mx1[rs[rt]]); mx2[rt]=max(mx2[ls[rt]],mx2[rs[rt]]); } void cut(int &rt,int l,int r,int k) { if(!rt) { return ; } if(l==r) { rt=0; return ; } int mid=(l+r)>>1; if(k<=mid) { rs[rt]=0; cut(ls[rt],l,mid,k); } else { cut(rs[rt],mid+1,r,k); } pushup(rt); } int merge(int x,int y,int l,int r,ll deep) { if(!x||!y) { return x+y; } int mid=(l+r)>>1; if(l==r) { mx1[x]=max(mx1[x],mx1[y]); mx2[x]=max(mx2[x],mx2[y]); return x; } ans=max(ans,mx1[ls[x]]+mx2[rs[y]]-deep); ans=max(ans,mx1[ls[y]]+mx2[rs[x]]-deep); ls[x]=merge(ls[x],ls[y],l,mid,deep); rs[x]=merge(rs[x],rs[y],mid+1,r,deep); pushup(x); return x; } void dsu(int x) { for(int i=head[x];i;i=next[i]) { dsu(to[i]); root[x]=merge(root[x],root[to[i]],1,n,d[x]); } cut(root[x],1,n,dep[x]-1); } miku merge(miku &x,miku &y,ll deep) { miku res; ll value; res=x.len>y.len?x:y; for(int i=0;i<2;i++) { for(int j=0;j<2;j++) { if(x.u[i]&&y.u[j]) { value=dis(x.u[i],y.u[j])+x.v[i]+y.v[j]; ans=max(ans,(value-deep)>>1); if(value>res.len) { res.u[0]=x.u[i]; res.u[1]=y.u[j]; res.v[0]=x.v[i]; res.v[1]=y.v[j]; res.len=value; } } } } return res; } void dsu(int x,int rt) { int size=q[x].size(); for(int i=0;i<size;i++) { int to=q[x][i]; dsu(to,rt); if(x!=rt) { tr[x]=merge(tr[x],tr[to],d[x]*2); } } vis[x]=0; q[x].clear(); } void insert(int x) { int fa=lca(x,st[top]); if(!vis[fa]) { vis[fa]=1; t[++num]=fa; } while(top>1&&dep[st[top-1]]>=dep[fa]) { q[st[top-1]].push_back(st[top]); top--; } if(st[top]!=fa) { q[fa].push_back(st[top]); st[top]=fa; } st[++top]=x; } void work() { n=rd(); for(int i=1;i<=n;i++) { head[i]=0; root[i]=0; v[i].clear(); } ans=-INF; tot=0; dfn=0; cnt=0; dep[1]=1; for(int i=1;i<n;i++) { x=rd();y=rd();z=rd2(); add(x,y,z); } dfs(1); ST(); m=rd(); for(int i=1;i<=m;i++) { x=rd();y=rd();z=rd2(); int fa=lca(x,y); ll value=dis(x,y)-z; if(x!=fa) { change(root[x],1,n,dep[fa],value,value+d[fa],d[x]); v[fa].push_back(make_pair(x,make_pair(y,value+d[x]-z))); } if(y!=fa) { change(root[y],1,n,dep[fa],value,value+d[fa],d[y]); v[fa].push_back(make_pair(y,make_pair(x,value+d[y]-z))); } } dsu(1); for(int i=1;i<=n;i++) { sort(v[i].begin(),v[i].end(),cmp); top=0; num=0; st[++top]=i; vis[i]=1; t[++num]=i; int size=v[i].size(); for(int j=0;j<size;j++) { if(!vis[v[i][j].first]) { vis[v[i][j].first]=1; t[++num]=v[i][j].first; insert(v[i][j].first); } } while(top>1) { q[st[top-1]].push_back(st[top]); top--; } for(int j=1;j<=num;j++) { int now=t[j]; tr[now].u[0]=tr[now].u[1]=0; tr[now].v[0]=tr[now].v[1]=-INF; tr[now].len=-INF; } for(int j=0;j<size;j++) { miku now=miku(v[i][j].second.first,v[i][j].second.second); tr[v[i][j].first]=merge(tr[v[i][j].first],now,d[v[i][j].first]*2); } dsu(i,i); } if(ans<=-1e16) { printf("F "); } else { printf("%lld ",ans); } } int main() { T=rd(); mx1[0]=mx2[0]=-INF; while(T--) { work(); } }