// 树套树（线段树套 splay）—— AcWing 2476
// 本来想用线段树套 multiset，结果一直超时（TLE）；改成常数更小的 splay 才通过，写完人都傻了 ^^
/*
线段树套 splay：外层线段树的每个节点各维护一棵 splay，存放该节点区间内的全部数值（另加 -INF / INF 两个哨兵）。
*/
#include <bits/stdc++.h>
using namespace std;
const int N = 50000 + 5;        // upper bound on n, plus a little slack
const int M = 10000000 + 5;     // capacity of the shared splay node pool tr[]
const int INF = (1 << 30) - 1;  // sentinel magnitude, same value as 0x3fffffff
// One node of the pooled splay trees. Pool index 0 is used as the null node
// throughout the file (loops stop on 0, tr[0].size is read as 0).
struct SplayNode{
    int s[2],p,v;   // s[0]/s[1]: left/right child index, p: parent index, v: key
    int size;       // number of nodes in the subtree rooted here

    // Turn a freshly allocated slot into a single-node subtree with key _v
    // and parent _p. size must start at 1: a node that becomes a root
    // without any rotation (first insert into an empty tree) is never
    // push_up'd, so it would otherwise report a size of 0.
    void init(int _v,int _p){
        s[0] = s[1] = 0;
        p = _p;
        v = _v;
        size = 1;
    }
};
// Problem input: array length and number of operations.
int n, m;
// Current operation code and its arguments, as read in main().
int op, L, R, X;
// Last allocated slot in the splay node pool; slot 0 stays the null node.
int idx;
// root[rt]: root of the splay tree owned by segment-tree node rt.
int root[N<<2];
// Shared pool holding the nodes of every splay tree.
SplayNode tr[M];
// Current array values, 1-indexed (arr[1..n]).
int arr[N];
// Side of its parent that x hangs on: 1 for a right child, 0 otherwise.
inline int ws(int x){
    const int parent = tr[x].p;
    return x == tr[parent].s[1] ? 1 : 0;
}
// Recompute tr[x].size from its two children (a missing child is index 0).
inline void push_up(int x){
    const int lc = tr[x].s[0];
    const int rc = tr[x].s[1];
    tr[x].size = 1 + tr[lc].size + tr[rc].size;
}
// Single rotation: lift x above its parent y while keeping in-order key order.
// k is the side x hangs on; x's inner child (on side k^1) moves across to y.
// Statement order matters: ws(y) must be read before y's parent is changed.
// NOTE(review): when y is the root (z == 0) or x's inner child is null, the
// writes below land in the null node tr[0]'s s/p fields. This looks harmless
// because push_up is never called on node 0, but tr[0] is not left pristine.
inline void rotate(int x){
int y = tr[x].p;
int z = tr[y].p;
int k = ws(x);
// modify z: x takes y's old slot under the grandparent
tr[z].s[ws(y)] = x;
// modify y: y adopts x's inner child on side k and becomes x's child
tr[y].s[k] = tr[x].s[k^1];
tr[y].p = x;
// modify the middle subtree (x's former inner child): its parent is now y
tr[tr[x].s[k^1]].p = y;
// modify x
tr[x].p = z;
tr[x].s[k^1] = y;
// y is now below x, so refresh y's size before x's
push_up(y);
push_up(x);
}
// Rotate x upward until its parent is k. k == 0 means "make x the root",
// in which case root[tr_id] is updated. Same-side (zig-zig) steps rotate
// the parent first; opposite-side (zig-zag) steps rotate x twice.
inline void splay(int x,int k,int tr_id){
    for (int y = tr[x].p; y != k; y = tr[x].p) {
        const int z = tr[y].p;
        if (z != k) {
            rotate(ws(x) == ws(y) ? y : x);
        }
        rotate(x);
    }
    if (k == 0) {
        root[tr_id] = x;
    }
}
// Add key v to splay tr_id as a brand-new pool node and splay it to the
// root. Keys equal to an existing node's key descend to the left.
// Returns the pool index of the new node.
inline int insert(int v,int tr_id){
    int parent = 0;
    for (int cur = root[tr_id]; cur != 0; cur = tr[cur].s[v > tr[cur].v]) {
        parent = cur;
    }
    const int fresh = ++idx;
    tr[fresh].init(v, parent);
    if (parent != 0) {
        tr[parent].s[v > tr[parent].v] = fresh;
    }
    splay(fresh, 0, tr_id);
    return fresh;
}
inline int count_less(int v,int tr_id){
int u = root[tr_id], cnt = 0;
while(u){
if(tr[u].v < v){
cnt += tr[tr[u].s[0]].size + 1;
u = tr[u].s[1];
}else{
u = tr[u].s[0];
}
}
return cnt-1; // -INF
}
// Rightmost node of x's left subtree (x's in-order predecessor when x is
// the root), or -1 if x has no left child.
inline int get_l(int x){
    int cur = tr[x].s[0];
    if (!cur) {
        return -1;
    }
    for (; tr[cur].s[1]; cur = tr[cur].s[1]) {
    }
    return cur;
}
// Leftmost node of x's right subtree (x's in-order successor when x is
// the root), or -1 if x has no right child.
inline int get_r(int x){
    int cur = tr[x].s[1];
    if (cur == 0) {
        return -1;
    }
    while (tr[cur].s[0] != 0) {
        cur = tr[cur].s[0];
    }
    return cur;
}
// Replace one occurrence of key x with key v in splay tr_id.
// Assumes x is present (modify() passes the current arr[pos]). If it is not,
// u ends as 0 and the null node gets splayed, which is not safe. TODO: guard?
// Deletion relies on the -INF/INF sentinels: a real key u always has both an
// in-tree predecessor l and successor r, so get_l/get_r do not return -1 here.
inline void update(int x,int v,int tr_id){
int u = root[tr_id];
// plain BST search for any node holding x (no splaying during the walk)
while(u){
if(tr[u].v == x) break;
if(tr[u].v < x) u = tr[u].s[1];
else u = tr[u].s[0];
}
splay(u, 0, tr_id);
// with u at the root, its in-order neighbours are the extreme nodes of its subtrees
int l = get_l(u);
int r = get_r(u);
// splay l to the root and r directly beneath it; u, the only node between
// them in order, is now exactly r's left subtree
splay(l, 0, tr_id);
splay(r, l, tr_id);
// detach u (its pool slot is never reused) and refresh sizes bottom-up
tr[r].s[0] = 0;
push_up(r);
push_up(l);
insert(v,tr_id);
}
int get_pre(int v,int tr_id){
int u = root[tr_id],ans = -INF;
while(u){
if(tr[u].v < v){
ans = max(ans, tr[u].v);
u = tr[u].s[1];
}else{
u = tr[u].s[0];
}
}
return ans;
}
int get_suc(int v,int tr_id){
int u = root[tr_id], ans = INF;
while(u){
if(tr[u].v > v){
ans = min(ans, tr[u].v);
u = tr[u].s[0];
}else{
u = tr[u].s[1];
}
}
return ans;
}
// Segment tree over positions 1..n; node rt's splay (root[rt]) holds the values of its interval.
// Build segment node rt covering positions [l, r]. Its splay receives the
// -INF/INF sentinels followed by arr[l..r] in order; both halves are then
// built recursively. Requires l <= r: an empty range would never terminate.
void build(int l,int r,int rt){
    insert(-INF, rt);
    insert(INF, rt);
    for (int pos = l; pos <= r; ++pos) {
        insert(arr[pos], rt);
    }
    if (l == r) {
        return;
    }
    const int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
}
// Number of values < x among arr[ql..qr], where node rt covers [l, r].
// The result is summed over the segment nodes that [ql, qr] fully covers.
int query_less(int l,int r,int rt,int ql,int qr,int x){
    if (ql <= l && r <= qr) {
        return count_less(x, rt);
    }
    const int mid = (l + r) >> 1;
    const int fromLeft = (ql <= mid) ? query_less(l, mid, rt << 1, ql, qr, x) : 0;
    const int fromRight = (mid < qr) ? query_less(mid + 1, r, rt << 1 | 1, ql, qr, x) : 0;
    return fromLeft + fromRight;
}
// k-th smallest value in arr[ql..qr]. Binary searches the value domain
// [0, 1e8] (assumed to contain every array value) for the largest m with
// fewer than k values < m. Returns -1 if no such m exists (e.g. k <= 0).
int rank_k(int ql,int qr,int k){
    int lo = 0;
    int hi = 100000000;
    int found = -1;
    while (lo <= hi) {
        const int mid = lo + (hi - lo) / 2;
        const bool tooBig = query_less(1, n, 1, ql, qr, mid) >= k;
        if (tooBig) {
            hi = mid - 1;
            continue;
        }
        found = mid;
        lo = mid + 1;
    }
    return found;
}
// Point assignment arr[pos] = v. Every segment node on the root-to-leaf path
// (starting at node rt covering [l, r]) swaps the old value for v in its
// splay. arr[pos] is overwritten only after all splays are updated.
void modify(int l,int r,int rt,int pos,int v){
    const int old = arr[pos];
    for (;;) {
        update(old, v, rt);
        if (l == r) {
            break;
        }
        const int mid = (l + r) >> 1;
        if (pos <= mid) {
            r = mid;
            rt = rt << 1;
        } else {
            l = mid + 1;
            rt = rt << 1 | 1;
        }
    }
    arr[pos] = v;
}
// Largest value < v among arr[ql..qr] (node rt covers [l, r]), or -INF if
// there is none.
int query_pre(int l,int r,int rt,int ql,int qr,int v){
    if (ql <= l && r <= qr) {
        return get_pre(v, rt);
    }
    const int mid = (l + r) >> 1;
    int best = -INF;
    if (ql <= mid) {
        const int cand = query_pre(l, mid, rt << 1, ql, qr, v);
        if (cand > best) best = cand;
    }
    if (mid < qr) {
        const int cand = query_pre(mid + 1, r, rt << 1 | 1, ql, qr, v);
        if (cand > best) best = cand;
    }
    return best;
}
// Smallest value > v among arr[ql..qr] (node rt covers [l, r]), or INF if
// there is none.
int query_suc(int l,int r,int rt,int ql,int qr,int v){
    if (ql <= l && r <= qr) {
        return get_suc(v, rt);
    }
    const int mid = (l + r) >> 1;
    int best = INF;
    if (ql <= mid) {
        const int cand = query_suc(l, mid, rt << 1, ql, qr, v);
        if (cand < best) best = cand;
    }
    if (mid < qr) {
        const int cand = query_suc(mid + 1, r, rt << 1 | 1, ql, qr, v);
        if (cand < best) best = cand;
    }
    return best;
}
int main(){
scanf("%d%d",&n,&m);
for(int i = 1; i <= n; ++i){
scanf("%d",&arr[i]);
}
build(1, n, 1);
while(m--){
scanf("%d",&op);
if(op == 1){
scanf("%d%d%d",&L,&R,&X);
printf("%d
",query_less(1, n, 1, L, R, X)+1);
}else if(op == 2){
scanf("%d%d%d",&L,&R,&X);
printf("%d
",rank_k(L, R, X));
}else if(op == 3){
scanf("%d%d",&L,&X);
modify(1, n, 1, L, X);
}else if(op == 4){
scanf("%d%d%d",&L,&R,&X);
printf("%d
",query_pre(1, n, 1, L, R, X));
}else{
scanf("%d%d%d",&L,&R,&X);
printf("%d
",query_suc(1, n, 1, L, R, X));
}
}
return 0;
}