(color{#0066ff}{ 题目描述 })
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
- 查询k在区间内的排名
- 查询区间内排名为k的值
- 修改某一位值上的数值
- 查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
- 查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)
(color{#0066ff}{输入格式})
第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继
(color{#0066ff}{输出格式})
对于操作1,2,4,5各输出一行,表示查询结果
(color{#0066ff}{输入样例})
9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
(color{#0066ff}{输出样例})
2
4
3
4
9
(color{#0066ff}{数据范围与提示})
时空限制:2s,128M
(n,m leq 5cdot {10}^4)保证有序序列所有值在任何时刻满足 ([0, {10} ^8])
(color{#0066ff}{ 题解 })
可以线段树套平衡树
对于操作1,线段树每个区间在平衡树上找比k小的数的个数,加起来再加1就是排名
对于操作2,可以二分答案,然后通过操作1来判断(O(log^3n))
对于操作3,相当于删除再插入,注意线段树整个一条链都要改
对于操作4,5,线段树子区间答案取max和min即可
#include<bits/stdc++.h>
#define LL long long
LL in() {
char ch; LL x = 0, f = 1;
while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
return x * f;
}
const int maxn = 5e4 + 10;
const int inf = 0x7fffffff;
struct Splay {
protected:
struct node {
node *ch[2], *fa;
int val, siz;
node(node *fa = NULL, int val = 0, int siz = 0): fa(fa), val(val), siz(siz) { ch[0] = ch[1] = NULL; }
void upd() { siz = (ch[0]? ch[0]->siz : 0) + (ch[1]? ch[1]->siz : 0) + 1; }
bool isr() { return this == fa->ch[1]; }
int rk() { return ch[0]? ch[0]->siz + 1 : 1; }
}*root;
void rot(node *x) {
node *y = x->fa, *z = y->fa;
bool k = x->isr(); node *w = x->ch[!k];
if(y != root) z->ch[y->isr()] = x;
else root = x;
x->ch[!k] = y, y->ch[k] = w;
y->fa = x, x->fa = z;
if(w) w->fa = y;
y->upd(), x->upd();
}
void splay(node *o) {
while(o != root) {
if(o->fa != root) rot(o->isr() ^ o->fa->isr()? o : o->fa);
rot(o);
}
}
node *merge(node *x, node *y, node *fa) {
if(x) x->fa = fa;
if(y) y->fa = fa;
if(!x || !y) return x? x : y;
if(rand() & 1) return x->ch[1] = merge(x->ch[1], y, x), x->upd(), x;
else return y->ch[0] = merge(x, y->ch[0], y), y->upd(), y;
}
public:
int rnk(int val) {
node *o = root, *lst = root; int rank = 0;
while(o) {
lst = o;
if(val > o->val) rank += o->rk(), o = o->ch[1];
else o = o->ch[0];
}
return splay(lst), rank;
}
int kth(int k) {
node *o = root;
while(o->rk() != k) {
if(k > o->rk()) k -= o->rk(), o = o->ch[1];
else o = o->ch[0];
}
return splay(o), o->val;
}
int pre(int val) {
node *o = root, *lst = root;
while(o) {
if(o->val < val) lst = o, o = o->ch[1];
else o = o->ch[0];
}
return splay(lst), lst->val;
}
int nxt(int val) {
node *o = root, *lst = root;
while(o) {
if(o->val > val) lst = o, o = o->ch[0];
else o = o->ch[1];
}
return splay(lst), lst->val;
}
void ins(int val) {
if(!root) return (void)(root = new node(NULL, val, 1));
node *o = root, *fa = NULL;
while(o) fa = o, o = o->ch[val > o->val];
fa->ch[val > fa->val] = o = new node(fa, val, 1);
splay(o);
}
void del(int val) {
node *o = root;
while(o->val != val) o = o->ch[val > o->val];
if(!o) return;
splay(o);
root = merge(o->ch[0], o->ch[1], NULL);
delete o;
}
};
struct SGT {
private:
struct node {
int l, r;
node *ch[2];
Splay *s;
node(int l = 0, int r = 0, Splay *s = NULL): l(l), r(r), s(s) { ch[0] = ch[1] = NULL; }
}*root;
void build(node *&o, int l, int r, int *a) {
o = new node(l, r, new Splay());
for(int i = l; i <= r; i++) o->s->ins(a[i]);
if(l == r) return;
int mid = (l + r) >> 1;
build(o->ch[0], l, mid, a), build(o->ch[1], mid + 1, r, a);
}
int rnk(node *o, int l, int r, int val) {
if(o->r < l || o->l > r) return 0;
if(l <= o->l && o->r <= r) return o->s->rnk(val);
return rnk(o->ch[0], l, r, val) + rnk(o->ch[1], l, r, val);
}
int pre(node *o, int l, int r, int val) {
if(o->r < l || o->l > r) return inf;
if(l <= o->l && o->r <= r) return o->s->pre(val);
int ans = -inf;
int L = pre(o->ch[0], l, r, val);
int R = pre(o->ch[1], l, r, val);
if(L < val) ans = std::max(ans, L);
if(R < val) ans = std::max(ans, R);
return ans;
}
int nxt(node *o, int l, int r, int val) {
if(o->r < l || o->l > r) return -inf;
if(l <= o->l && o->r <= r) return o->s->nxt(val);
int ans = inf;
int L = nxt(o->ch[0], l, r, val);
int R = nxt(o->ch[1], l, r, val);
if(L > val) ans = std::min(ans, L);
if(R > val) ans = std::min(ans, R);
return ans;
}
void change(node *o, int pos, int val, int old) {
if(o->r < pos || o->l > pos) return;
o->s->del(old);
o->s->ins(val);
if(o->l == o->r) return;
change(o->ch[0], pos, val, old);
change(o->ch[1], pos, val, old);
}
public:
void build(int *a, int l, int r) { build(root, l, r, a); }
int rnk(int val, int l, int r) { return rnk(root, l, r, val) + 1; }
int kth(int k, int L, int R) {
int l = 0, r = 1e8, ans = 0;
while(l <= r) {
int mid = (l + r) >> 1;
if(rnk(mid, L, R) <= k) ans = mid, l = mid + 1;
else r = mid - 1;
}
return ans;
}
void change(int pos, int old, int now) { change(root, pos, now, old); }
int pre(int val, int l, int r) { return pre(root, l, r, val); }
int nxt(int val, int l, int r) { return nxt(root, l, r, val); }
}v;
int a[maxn];
int main() {
int p, l, r, k, n = in(), m = in();
for(int i = 1; i <= n; i++) a[i] = in();
v.build(a, 1, n);
while(m --> 0) {
p = in();
if(p == 1) l = in(), r = in(), k = in(), printf("%d
", v.rnk(k, l, r));
if(p == 2) l = in(), r = in(), k = in(), printf("%d
", v.kth(k, l, r));
if(p == 3) l = in(), k = in(), v.change(l, a[l], k), a[l] = k;
if(p == 4) l = in(), r = in(), k = in(), printf("%d
", v.pre(k, l, r));
if(p == 5) l = in(), r = in(), k = in(), printf("%d
", v.nxt(k, l, r));
}
return 0;
}