题目:http://www.lydsy.com/JudgeOnline/problem.php?id=3196
这题刷新了我最长代码记录。。
在每个线段树节点上建一个splay,维护该区间内的信息
操作1:只需要用线段树查找区间,用splay查找比k大的有多少个,最后加起来
操作2:二分答案,同样用splay判断比现在的答案大的有多少个,二分到最后一定是比答案大一,所以减一以后就是答案
操作3:相当于线段树的单点修改,维护splay只需要删除再插入
操作4:相当于线段树区间查询,每次找到比k小的最大的数,取max
操作5:类似操作4,每次找到比k大的最小的数,取min
#include<iostream> #include<cstdio> #include<cstring> #include<string> #include<cmath> #include<algorithm> #include<vector> #include<queue> #include<stack> #include<map> #include<set> #define lson i<<1 #define rson i<<1|1 using namespace std; const int N=4e6+5; const int inf=2e9; int a[N],f[N],ch[N][2],key[N],cnt[N],sz[N],root[N]; int n,maxn,tot,ans; inline void splay_clear(int x) { f[x]=ch[x][0]=ch[x][1]=cnt[x]=key[x]=sz[x]=0; } inline int splay_get(int x) { return ch[f[x]][1]==x; } void splay_update(int x) { if (x) { sz[x]=cnt[x]; if (ch[x][0]) sz[x]+=sz[ch[x][0]]; if (ch[x][1]) sz[x]+=sz[ch[x][1]]; } } void splay_rotate(int x) { int fa=f[x],ff=f[fa],kind=splay_get(x); ch[fa][kind]=ch[x][kind^1];f[ch[x][kind^1]]=fa; ch[x][kind^1]=fa;f[fa]=x; f[x]=ff; if (ff) ch[ff][ch[ff][1]==fa]=x; splay_update(fa); splay_update(x); } void splay_splay(int i,int x) { for(int fa;(fa=f[x]);splay_rotate(x)) if (f[fa]) splay_rotate((splay_get(fa)==splay_get(x))?fa:x); root[i]=x; } void splay_insert(int i,int x) { if (!root[i]) { root[i]=++tot; f[tot]=ch[tot][0]=ch[tot][1]=0; cnt[tot]=sz[tot]=1; key[tot]=x; return; } int now=root[i],fa=0; while(1) { if (x==key[now]) { cnt[now]++; splay_update(fa); splay_splay(i,now); return; } fa=now;now=ch[now][key[now]<x]; if (now==0) { ++tot; f[tot]=fa;ch[tot][0]=ch[tot][1]=0; cnt[tot]=sz[tot]=1; key[tot]=x; ch[fa][key[fa]<x]=tot; splay_update(fa); splay_splay(i,tot); return; } } } int splay_findrank(int i,int x) { int now=root[i],tem=0; while(1) { if (!now) return tem; if (key[now]==x) return (ch[now][0]?sz[ch[now][0]]:0)+tem; else if (key[now]<x) { tem+=cnt[now]; if (ch[now][0]) tem+=sz[ch[now][0]]; now=ch[now][1]; } else now=ch[now][0]; } } void splay_find(int i,int x) { int now=root[i]; while(1) { if (key[now]==x) { splay_splay(i,now); return; } else if (x<key[now]) now=ch[now][0]; else now=ch[now][1]; } } int splay_pre(int i) { int now=root[i]; now=ch[now][0]; while(ch[now][1]) now=ch[now][1]; return now; } void splay_del(int i) { int now=root[i]; if (cnt[now]>1) { cnt[now]--; splay_update(now); return; } if (!ch[now][0]&&!ch[now][1]) { splay_clear(now); root[i]=0; return; } if (!ch[now][0]) { int oldroot=now; root[i]=ch[oldroot][1]; f[root[i]]=0; splay_clear(oldroot); return; } if (!ch[now][1]) { int oldroot=now; root[i]=ch[oldroot][0]; f[root[i]]=0; splay_clear(oldroot); return; } int leftbig=splay_pre(i),oldroot=now; splay_splay(i,leftbig); ch[root[i]][1]=ch[oldroot][1]; f[ch[oldroot][1]]=root[i]; splay_clear(oldroot); splay_update(root[i]); } void splay_findpre(int i,int x) { int now=root[i]; while(now) { if (key[now]<x) { if (ans<key[now]) ans=key[now]; now=ch[now][1]; } else now=ch[now][0]; } } void splay_findnext(int i,int x) { int now=root[i]; while(now) { if (key[now]>x) { if (ans>key[now]) ans=key[now]; now=ch[now][0]; } else now=ch[now][1]; } } void seg_insert(int i,int l,int r,int x,int v) { splay_insert(i,v); if (l==r) return; int mid=(l+r)>>1; if (x<=mid) seg_insert(lson,l,mid,x,v); else seg_insert(rson,mid+1,r,x,v); } void seg_findrank(int i,int l,int r,int L,int R,int k) { if (L<=l&&r<=R) { ans+=splay_findrank(i,k); return; } int mid=(l+r)>>1; if (L<=mid) seg_findrank(lson,l,mid,L,R,k); if (R>mid) seg_findrank(rson,mid+1,r,L,R,k); } void seg_change(int i,int l,int r,int pos,int k) { splay_find(i,a[pos]); splay_del(i); splay_insert(i,k); if (l==r) return; int mid=(l+r)>>1; if (pos<=mid) seg_change(lson,l,mid,pos,k); else seg_change(rson,mid+1,r,pos,k); } void seg_findpre(int i,int l,int r,int L,int R,int k) { if (L<=l&&r<=R) { splay_findpre(i,k); return; } int mid=(l+r)>>1; if (L<=mid) seg_findpre(lson,l,mid,L,R,k); if (R>mid) seg_findpre(rson,mid+1,r,L,R,k); } void seg_findnext(int i,int l,int r,int L,int R,int k) { if (L<=l&&r<=R) { splay_findnext(i,k); return; } int mid=(l+r)>>1; if (L<=mid) seg_findnext(lson,l,mid,L,R,k); if (R>mid) seg_findnext(rson,mid+1,r,L,R,k); } int main() { int m,t; scanf("%d%d",&n,&m); maxn=0;tot=0; for(int i=1;i<=n;i++) { scanf("%d",&a[i]); maxn=max(maxn,a[i]); seg_insert(1,1,n,i,a[i]); } while(m--) { scanf("%d",&t); int l,r,k,pos,head,tail; switch(t) { case 1: scanf("%d%d%d",&l,&r,&k); ans=0; seg_findrank(1,1,n,l,r,k); printf("%d ",ans+1); break; case 2: scanf("%d%d%d",&l,&r,&k); head=0;tail=maxn+1; while(head<tail) { int mid=(head+tail)>>1; ans=0; seg_findrank(1,1,n,l,r,mid); if (ans<k) head=mid+1; else tail=mid; } printf("%d ",head-1); break; case 3: scanf("%d%d",&pos,&k); seg_change(1,1,n,pos,k); a[pos]=k; maxn=max(maxn,k); break; case 4: scanf("%d%d%d",&l,&r,&k); ans=0; seg_findpre(1,1,n,l,r,k); printf("%d ",ans); break; case 5: scanf("%d%d%d",&l,&r,&k); ans=inf; seg_findnext(1,1,n,l,r,k); printf("%d ",ans); break; } } return 0; }