将选取的$A$看成左括号,$B$看成右括号,那么答案是一个合法的括号序列。
那么只要重复取出$k$对价值最小的左右括号,保证每时每刻都是一个合法的括号序列即可。
将$($看成$1$,$)$看成$-1$,设$s[]$为前缀和。
如果当前取出的是$()$,那么对前缀和的影响为$[A,B-1]$区间加$1$。
如果当前取出的是$)($,那么对前缀和的影响为$[B,A-1]$区间减$1$,所以这种情况需要满足区间$s$的最小值不为$0$。
考虑用线段树维护这个序列,线段树上每个节点维护以下信息:
va:$Aleq B$情况的最优解。
vb:$A>B$情况的最优解,且满足$[A,B-1]$的区间$s$最小值大于当前区间的$s$最小值。
vc:$A>B$情况的最优解。
aa:区间内代价最小的$A$。
ab:区间内代价最小的$B$。
ba:区间内代价最小的$A$,满足$[st,A-1]$的区间$s$最小值大于区间$s$最小值。
bb:区间内代价最小的$B$,满足$[B,en]$的区间$s$最小值大于区间$s$最小值。
vm:区间$s$最小值。
tag:区间增量标记。
为了方便维护,可以考虑增加第$0$项,$A[0]=B[0]=inf$。
那么$[0,n]$区间的$vb$必定满足区间最小值不为$0$,然后贪心选取$k$次即可求出最优解。
时间复杂度$O(klog n)$。
#include<cstdio> const int N=500010,M=1050000,inf=1000000010; long long ans; int n,k,i,j,A[N],B[N],aa[M],ab[M],ba[M],bb[M],vm[M],tag[M]; struct P{ int x,y; P(){} P(int _x,int _y){x=_x,y=_y;} P operator+(const P&b){return A[x]+B[y]<A[b.x]+B[b.y]?*this:b;} }va[M],vb[M],vc[M],t; inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';} inline void add1(int x,int p){vm[x]+=p;tag[x]+=p;} inline void pb(int x){if(tag[x])add1(x<<1,tag[x]),add1(x<<1|1,tag[x]),tag[x]=0;} inline void up(int x){ int l=x<<1,r=l|1; va[x]=va[l]+va[r]+P(aa[l],ab[r]); vc[x]=vc[l]+vc[r]+P(aa[r],ab[l]); vb[x]=vb[l]+vb[r]; aa[x]=A[aa[l]]<A[aa[r]]?aa[l]:aa[r]; ab[x]=B[ab[l]]<B[ab[r]]?ab[l]:ab[r]; if(vm[l]<vm[r]){ vb[x]=vb[x]+vc[r]+P(aa[r],bb[l]); ba[x]=ba[l]; bb[x]=B[ab[r]]<B[bb[l]]?ab[r]:bb[l]; vm[x]=vm[l]; } if(vm[l]>vm[r]){ vb[x]=vb[x]+vc[l]+P(ba[r],ab[l]); ba[x]=A[aa[l]]<A[ba[r]]?aa[l]:ba[r]; bb[x]=bb[r]; vm[x]=vm[r]; } if(vm[l]==vm[r]){ vb[x]=vb[x]+P(ba[r],bb[l]); ba[x]=ba[l]; bb[x]=bb[r]; vm[x]=vm[l]; } } void build(int x,int a,int b){ if(a==b){ va[x]=vc[x]=P(a,a),vb[x]=P(0,0); aa[x]=ab[x]=ba[x]=a; return; } int mid=(a+b)>>1; build(x<<1,a,mid),build(x<<1|1,mid+1,b),up(x); } void add(int x,int a,int b,int c,int d,int p){ if(c<=a&&b<=d){add1(x,p);return;} pb(x); int mid=(a+b)>>1; if(c<=mid)add(x<<1,a,mid,c,d,p); if(d>mid)add(x<<1|1,mid+1,b,c,d,p); up(x); } void change(int x,int a,int b,int c){ if(a==b)return; pb(x); int mid=(a+b)>>1; if(c<=mid)change(x<<1,a,mid,c);else change(x<<1|1,mid+1,b,c); up(x); } int main(){ read(n),read(k); for(i=1;i<=n;i++)read(A[i]); for(i=1;i<=n;i++)read(B[i]); A[0]=B[0]=inf; build(1,0,n); while(k--){ t=va[1]+vb[1],i=t.x,j=t.y,ans+=A[i]+B[j]; if(i<j)add(1,0,n,i,j-1,1); if(i>j)add(1,0,n,j,i-1,-1); A[i]=inf,change(1,0,n,i); B[j]=inf,change(1,0,n,j); } return printf("%lld",ans),0; }