解法:用线段树维护区间内的MST
具体合并方法如下:
首先,由于两个都是树结构,我们把区间间的两条边连起来,会出现一个环,我们只需要把环上最长的边cut掉即可
首先,绿色部分可以O(1)计算,我们需要维护的只是红色和蓝色部分,即区间内最左和最右的竖边,以及其左/右横边的Max .
如果我们cut的是是横边,或者虽然cut了竖边但区间内有多条竖边(cut蓝色边),那么大区间的lc(最左竖边再左边横框的Max,即蓝色部分),rc(红色部分)都可以直接用两个小区间的lc,rc更新。
但是如果我们cut的竖边是区间内唯一的竖边,如红色边,那么我们新区间的lc需用蓝,绿,以及左边横边的Max更新,由于左边区间只有一条竖边,那么该区间的lc,rc即囊括了所有横边,直接用其来更新即可
右边同理。
End.
#include<cstdio> #include<cstdlib> #include<algorithm> #include<cmath> #include<cstring> using namespace std; struct Tree{ int lef,rig,sum,lc,rc,dn; }tr[300011]; int A[60011],B[60011],c[60011]; int n,m,i,sx,sy,tx,ty,w,ans,l,r; Tree Tr; char s[11]; Tree Merge(Tree a,Tree b,int r) { bool pw,cut; int kd,mx; Tree c; mx=0; mx=max(a.rig,mx);mx=max(b.lef,mx); mx=max(a.rc,mx);mx=max(b.lc,mx); mx=max(A[r],mx);mx=max(B[r],mx); pw=false; cut=false; if(mx==a.rig){ if(a.dn==1){ pw=true; kd=1; } cut=true; } if(mx==b.lef){ if(b.dn==1){ pw=true; kd=2; } cut=true; } c.dn=a.dn+b.dn; if(cut)c.dn--; c.sum=a.sum+b.sum+A[r]+B[r]-mx; if(pw==false){ c.lc=a.lc;c.rc=b.rc; c.lef=a.lef;c.rig=b.rig; } else if(kd==1){ c.rc=b.rc;c.rig=b.rig; c.lc=b.lc;c.lc=max(c.lc,a.rc);c.lc=max(c.lc,a.lc); c.lc=max(c.lc,A[r]);c.lc=max(c.lc,B[r]); c.lef=b.lef; } else{ c.lc=a.lc;c.lef=a.lef; c.rc=a.rc;c.rc=max(c.rc,b.lc);c.rc=max(c.rc,b.rc); c.rc=max(c.rc,A[r]);c.rc=max(c.rc,B[r]); c.rig=a.rig; } return c; } void build(int l,int r,int t) { if(l==r){ tr[t].lc=tr[t].rc=0; tr[t].lef=tr[t].rig=c[l]; tr[t].dn=1; tr[t].sum=c[l]; return; } int mid; mid=(l+r)/2; build(l,mid,t+t); build(mid+1,r,t+t+1); tr[t]=Merge(tr[t+t],tr[t+t+1],mid); } void insert(int t,int l,int r,int x) { if(l==r){ tr[t].lef=tr[t].rig=c[l]; tr[t].sum=c[l]; return; } int mid; mid=(l+r)/2; if(x<=mid)insert(t+t,l,mid,x); if(x>mid)insert(t+t+1,mid+1,r,x); tr[t]=Merge(tr[t+t],tr[t+t+1],mid); } Tree ask(int t,int l,int r,int x,int y) { int mid; if(l==x&&r==y)return tr[t]; mid=(l+r)/2; if(y<=mid)return ask(t+t,l,mid,x,y); if(x>mid)return ask(t+t+1,mid+1,r,x,y); if(x<=mid&&y>mid)return Merge(ask(t+t,l,mid,x,mid),ask(t+t+1,mid+1,r,mid+1,y),mid); } int main() { scanf("%d%d",&n,&m); for(i=1;i<n;i++)scanf("%d",&A[i]); for(i=1;i<n;i++)scanf("%d",&B[i]); for(i=1;i<=n;i++)scanf("%d",&c[i]); build(1,n,1); for(i=1;i<=m;i++){ scanf("%s",&s); if(s[0]=='C'){ scanf("%d%d%d%d%d",&sx,&sy,&tx,&ty,&w); if(sy==ty)c[sy]=w; else{ if(sy>ty)swap(sy,ty); if(sx==1)A[sy]=w; else B[sy]=w; } insert(1,1,n,sy); } else{ scanf("%d%d",&l,&r); Tr=ask(1,1,n,l,r); ans=Tr.sum; printf("%d ",ans); } } }