只要坑不死,就往死里坑呗。。。
二叉搜索树的模板题,本题无删除操作
核心函数:
- query_by_value(int root, int val): 根据值来查找排名
- query_by_rank(int root, int rk): 根据排名找值
- insert(int root, int val): 向树中插入val
- query_pre(int root, int val): 找值val的前驱(最大的小于val的值)
- query_suc(int root, int val): 找值val的后继(最小的大于val的值)
坑点:
- 可能会插入多个相同的值:为每一个树的结点设置cnt
- 在根据value来找rank的时候,可能树中不存在value,但是同样需要输出它的rank
- 根据rank找value时,如果找不到对应rank应该输出INF
平均查询复杂度:(O(logn))
平均插入复杂度:(O(logn))
平均构造一个n个结点的BST的复杂度:(O(nlogn))
最坏情况下的插入一个值复杂度:(O(n))
最坏情况下的查询一个值复杂度:(O(n))
最坏情况下产生一个n个结点的BST的复杂度:(O(n^2)), 比如通过一个降序的序列通过插入的方法来构造BST
代码
#include<iostream>
using namespace std;
const int N = 40010, INF = 2147483647;
struct Node{
int val;
int l, r;
int cnt;
int size;
}tr[N];
int ss = 1;
void insert(int x, int val){
tr[x].size ++;
if(ss == 1){
tr[x].val = val;
tr[x].cnt ++;
ss ++;
return;
}
if(val == tr[x].val) tr[x].cnt ++;
else if(val > tr[x].val){
if(tr[x].r) insert(tr[x].r, val);
else{
tr[x].r = ss;
tr[ss].val = val;
tr[ss].cnt ++;
tr[ss].size ++;
ss ++;
}
}else if(val < tr[x].val){
if(tr[x].l) insert(tr[x].l, val);
else{
tr[x].l = ss;
tr[ss].val = val;
tr[ss].cnt ++;
tr[ss].size ++;
ss ++;
}
}
}
int query_by_value(int x, int val){
if(val < tr[x].val){
if(tr[x].l) return query_by_value(tr[x].l, val);
return 1;
}
if(val > tr[x].val){
if(tr[x].r) return query_by_value(tr[x].r, val) + tr[tr[x].l].size + tr[x].cnt;
return tr[tr[x].l].size + tr[x].cnt + 1;
}
return tr[tr[x].l].size + 1;
}
int query_by_rank(int x, int rk){
if(x == 0) return INF;
if(rk <= tr[tr[x].l].size) return query_by_rank(tr[x].l, rk);
if(rk > tr[tr[x].l].size + tr[x].cnt) return query_by_rank(tr[x].r, rk - tr[tr[x].l].size - tr[x].cnt);
if(rk > tr[tr[x].l].size && rk <= tr[tr[x].l].size + tr[x].cnt) return tr[x].val;
}
int query_pre(int x, int val){
if(tr[x].val < val){
if(tr[x].r) return max(tr[x].val, query_pre(tr[x].r, val));
return tr[x].val;
}
if(tr[x].val >= val){
if(tr[x].l) return query_pre(tr[x].l, val);
return -INF;
}
}
int query_suc(int x, int val){
if(tr[x].val > val){
if(tr[x].l) return min(tr[x].val, query_suc(tr[x].l, val));
return tr[x].val;
}
if(tr[x].val <= val){
if(tr[x].r) return query_suc(tr[x].r, val);
return INF;
}
}
int main(){
int q;
cin >> q;
while(q --){
int k, x;
cin >> k >> x;
switch(k){
case 1: cout << query_by_value(1, x) << endl; break;
case 2: cout << query_by_rank(1, x) << endl; break;
case 3: cout << query_pre(1, x) << endl; break;
case 4: cout << query_suc(1, x) << endl; break;
case 5: insert(1, x); break;
}
}
return 0;
}
2020.12.17编辑: 增加删除操作
- 根据值删除:del_val(int u, int val)
- 根据排名删除:del_rk(int u, int rk),按排名删除可以通过query_by_rank(int u, int rk)转化为根据值删除。
需要注意的是删除需要分成三种情况:
- 删除的结点只有左子树
- 删除的结点只有右子树
- 删除的结点有左右子树
1、2两种直接将左/右子树上拉即可,第三种结点(设为u)的处理方法:找v(最大的比u的值小的结点 <=> 左子树的最右结点) 或(最小的比u大的结点 <=> 右子树的最左结点),将其值和cnt赋值给u,然后将v删除,将v删除必定属于前两种情况。
注意:方便起见,约定不存在可能导致根结点被删除的操作。如果根结点可能被删除,那么就得自己另外搞一个初始根结点(树为空的时候就存在,它的值是一个比所有插入的数都小的数),那样输出排名的时候需要-1,另外查询排名的时候要先+1。
#include<iostream>
using namespace std;
const int N = 40010, INF = 2147483647;
struct Node{
int val;
int cnt; // 当前结点存储的相同的val的个数
int l, r;
int size; // 以当前结点为根的树的大小
}tr[N];
int ss = 1;
void insert(int u, int val){
tr[u].size ++;
if(ss == 1){
tr[u].val = val;
tr[u].cnt ++;
ss ++;
return;
}
if(val == tr[u].val) tr[u].cnt ++;
else if(val < tr[u].val){
if(tr[u].l) insert(tr[u].l, val);
else{
tr[u].l = ss;
tr[ss].val = val;
tr[ss].size ++;
tr[ss].cnt ++;
ss ++;
}
}else if(val > tr[u].val){
if(tr[u].r) insert(tr[u].r, val);
else{
tr[u].r = ss;
tr[ss].val = val;
tr[ss].size ++;
tr[ss].cnt ++;
ss ++;
}
}
}
int del_rk(int u, int rk){ // 找rk在树上的位置
if(u == 0) return 0;
int l = tr[tr[u].l].size + 1;
int r = l + tr[u].cnt - 1;
// cout << tr[u].val << endl;
// cout << l << ' ' << r << endl;
if(rk >= l && rk <= r){
tr[u].cnt --;
tr[u].size --;
if(tr[u].cnt) return 1;
if(tr[u].l && tr[u].r){
int p = tr[u].l, q = p;
while(tr[p].r) p = tr[p].r;
tr[u].val = tr[p].val;
tr[u].cnt = tr[p].cnt;
p = q;
while(tr[p].r){
q = p;
tr[p].size -= tr[u].cnt;
p = tr[p].r;
}
if(p == q) tr[u].l = 0;
else tr[q].r = tr[p].l;
}
return 1;
}
if(rk < l && del_rk(tr[u].l, rk)){
tr[u].size --;
int l = tr[u].l;
if(tr[l].cnt == 0) tr[u].l = tr[l].l ? tr[l].l : (tr[l].r ? tr[l].r : 0);
return 1;
}
if(rk > r && del_rk(tr[u].r, rk - l)){
tr[u].size --;
int r = tr[u].r;
if(tr[r].cnt == 0) tr[u].r = tr[r].l ? tr[r].l : (tr[r].r ? tr[r].r : 0);
return 1;
}
return 0;
}
int del_val(int u, int val){ // 找val在树上的位置
if(u == 0) return 0;
if(val == tr[u].val){
tr[u].cnt --;
tr[u].size --;
if(tr[u].cnt) return 1;
if(tr[u].l && tr[u].r){
int p = tr[u].l, q = p;
while(tr[p].r) p = tr[p].r;
tr[u].val = tr[p].val;
tr[u].cnt = tr[p].cnt;
p = q;
while(tr[p].r){
q = p;
tr[p].size -= tr[u].cnt;
p = tr[p].r;
}
if(p == q) tr[u].l = 0;
else tr[q].r = tr[p].l;
}
return 1;
}
if(val > tr[u].val && del_val(tr[u].r, val)){
tr[u].size --;
int r = tr[u].r;
if(tr[r].cnt == 0) tr[u].r = tr[r].l ? tr[r].l : (tr[r].r ? tr[r].r : 0);
return 1;
}
if(val < tr[u].val && del_val(tr[u].l, val)){
tr[u].size --;
int l = tr[u].l;
if(tr[l].cnt == 0) tr[u].l = tr[l].l ? tr[l].l : (tr[l].r ? tr[l].r : 0);
return 1;
}
return 0;
}
int query_rk(int u, int val){
if(u == 0) return 1;
if(val == tr[u].val) return tr[tr[u].l].size + 1;
if(val < tr[u].val) return query_rk(tr[u].l, val);
return tr[tr[u].l].size + tr[u].cnt + query_rk(tr[u].r, val);
}
int query_val(int u, int rk){
if(u == 0) return INF;
int l = tr[tr[u].l].size + 1;
int r = l + tr[u].cnt - 1;
if(rk < l) return query_val(tr[u].l, rk);
if(rk > r) return query_val(tr[u].r, rk - r);
return tr[u].val;
}
int query_suc(int u, int val){
if(u == 0) return INF;
if(val >= tr[u].val) return query_suc(tr[u].r, val);
return min(tr[u].val, query_suc(tr[u].l, val));
}
int query_pre(int u, int val){
if(u == 0) return -INF;
if(val <= tr[u].val) return query_pre(tr[u].l, val);
return max(tr[u].val, query_pre(tr[u].r, val));
}
int main(){
int q;
cin >> q;
while(q --){
int k, x;
cin >> k >> x;
switch(k){
case 1: cout << query_rk(1, x) << endl; break;
case 2: cout << query_val(1, x) << endl; break;
case 3: cout << query_pre(1, x) << endl; break;
case 4: cout << query_suc(1, x) << endl; break;
case 5: insert(1, x); break;
case 6:
if(del_val(1, x)) printf("delete succeeded, size: %d
", tr[1].size);
else puts("delete failed");
break;
case 7:
if(del_rk(1, x)) printf("delete succeeded, size: %d
", tr[1].size);
else puts("delete failed");
break;
}
}
return 0;
}
/*
10
5 5
5 3
5 6
5 2
5 4
1 5
7 1
7 2
6 5
1 6
*/
/*
13
5 5
5 3
5 6
5 2
5 4
1 5
2 1
7 1
2 2
7 1
7 1
6 4
1 -100
*/
上面的代码:口胡最坏单次删除复杂度为(O(n)), 最好情况下删除一个叶子复杂度(O(logn))