某辣鸡考研党复习到平衡树时突然心血来潮想自己实现一下AVL树QAQ。快一年没敲代码了，码力下降严重，断断续续写了好久QAQ。写了快300行，不过是全凭自己感觉写的，也算是完成了当年没完成的心愿吧（自己独立写出一种平衡树）。
代码:
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

// Size of the node pool. Index 0 is a null sentinel whose dep/sz stay 0, so at
// most maxn - 1 nodes (distinct values) can be live at the same time.
const int maxn = 101010;

// AVL tree storing a multiset of ints in an array-backed node pool.
// Each node holds one distinct value plus its multiplicity (cnt); child and
// parent links are pool indices, with 0 meaning "no node".
struct AVL_tree {
    int tot = 0, root = 0;

    struct node {
        int ch[2], fa;  // left (0) / right (1) child indices and parent index
        // bal = dep(left) - dep(right); dep = height (a leaf has dep 1);
        // sz = total multiplicity stored in this subtree; cnt = copies of val.
        int bal, val, dep, sz, cnt;
        void init() { ch[0] = ch[1] = fa = bal = val = dep = sz = cnt = 0; }
    };
    node tr[maxn];

    // Results of the query routines; main() resets them before each call.
    int Next_pos, pre_pos, Rank, ans;

    // Slots released by del() and handed out again by creat(), so the pool
    // bounds the number of live nodes rather than the total number of inserts.
    std::vector<int> free_slots;

    int &ls(int pos) { return tr[pos].ch[0]; }
    int &rs(int pos) { return tr[pos].ch[1]; }
    int &Fa(int pos) { return tr[pos].fa; }

    // Resets the whole pool. (Clearing only sizeof(node) bytes would reset
    // just tr[0].)
    void init() {
        std::memset(tr, 0, sizeof(tr));
        tot = root = 0;
        free_slots.clear();
    }

    // Allocates a leaf holding one copy of `val` under parent `fa`; returns its index.
    int creat(int val, int fa) {
        int ret;
        if (!free_slots.empty()) {
            ret = free_slots.back();
            free_slots.pop_back();
        } else {
            ret = ++tot;
        }
        tr[ret].init();
        tr[ret].val = val;
        tr[ret].fa = fa;
        tr[ret].sz = 1;
        tr[ret].bal = 0;
        tr[ret].dep = 1;
        tr[ret].cnt = 1;
        return ret;
    }

    // 0 if pos is its parent's left child, 1 otherwise.
    // Only meaningful when pos has a parent.
    int son(int pos) { return ls(Fa(pos)) == pos ? 0 : 1; }

    // Puts `child` (possibly 0) where `pos` hangs: under pos's parent, or as the
    // root when pos has no parent. Never writes into the sentinel node 0.
    void replace_in_parent(int pos, int child) {
        int f = Fa(pos);
        if (f) {
            if (son(pos) == 0) ls(f) = child;
            else rs(f) = child;
        } else {
            root = child;
        }
        if (child) Fa(child) = f;
    }

    // Right rotation at pos: its left child becomes the subtree root.
    void rrotate(int pos) {
        int l = ls(pos);
        replace_in_parent(pos, l);  // must run before Fa(pos) changes
        ls(pos) = rs(l);
        if (ls(pos)) Fa(ls(pos)) = pos;
        rs(l) = pos;
        Fa(pos) = l;
        maintain1(pos);  // pos is now below l, so refresh it first
        maintain1(l);
    }

    // Left rotation at pos: its right child becomes the subtree root.
    void lrotate(int pos) {
        int r = rs(pos);
        replace_in_parent(pos, r);  // must run before Fa(pos) changes
        rs(pos) = ls(r);
        if (rs(pos)) Fa(rs(pos)) = pos;
        ls(r) = pos;
        Fa(pos) = r;
        maintain1(pos);
        maintain1(r);
    }

    // Restores balance at a node with |bal| == 2.
    void rotate(int pos) {
        if (tr[pos].bal > 1) {
            // Left-heavy. A left child with bal == 0 (possible after a deletion)
            // needs a single rotation; a double rotation would leave it
            // unbalanced.
            if (tr[ls(pos)].bal >= 0) {
                rrotate(pos);                  // LL
            } else {
                lrotate(ls(pos));              // LR
                rrotate(pos);
            }
        } else {
            if (tr[rs(pos)].bal <= 0) {
                lrotate(pos);                  // RR (or balanced right child)
            } else {
                rrotate(rs(pos));              // RL
                lrotate(pos);
            }
        }
    }

    // Recomputes dep, bal and sz of pos from its children; no rebalancing.
    void maintain1(int pos) {
        if (pos == 0) return;
        tr[pos].dep = std::max(tr[ls(pos)].dep, tr[rs(pos)].dep) + 1;
        tr[pos].bal = tr[ls(pos)].dep - tr[rs(pos)].dep;
        tr[pos].sz = tr[ls(pos)].sz + tr[rs(pos)].sz + tr[pos].cnt;
    }

    // Recomputes pos's fields and rotates if it became unbalanced.
    void maintain(int pos) {
        if (pos == 0) return;
        maintain1(pos);
        if (tr[pos].bal > 1 || tr[pos].bal < -1) rotate(pos);
    }

    // Inserts one copy of val into the non-empty subtree rooted at pos.
    void insert(int pos, int val) {
        if (tr[pos].val == val) {
            tr[pos].cnt++;
        } else if (val > tr[pos].val) {
            if (rs(pos) == 0) rs(pos) = creat(val, pos);
            else insert(rs(pos), val);
        } else {
            if (ls(pos) == 0) ls(pos) = creat(val, pos);
            else insert(ls(pos), val);
        }
        maintain(pos);
    }

    // Returns the index of the node holding x, or 0 if x is absent.
    int find(int pos, int x) {
        if (pos == 0) return 0;
        if (tr[pos].val == x) return pos;
        return tr[pos].val > x ? find(ls(pos), x) : find(rs(pos), x);
    }

    // Sets pre_pos to the node with the largest val < x (left unchanged if none).
    void pre(int pos, int x) {
        if (pos == 0) return;  // empty tree: do not report the sentinel
        if (tr[pos].val >= x) {
            if (ls(pos)) pre(ls(pos), x);
        } else {
            pre_pos = pos;
            if (rs(pos)) pre(rs(pos), x);
        }
    }

    // Sets Next_pos to the node with the smallest val > x (left unchanged if none).
    void Next(int pos, int x) {
        if (pos == 0) return;  // empty tree: do not report the sentinel
        if (tr[pos].val <= x) {
            if (rs(pos)) Next(rs(pos), x);
        } else {
            Next_pos = pos;
            if (ls(pos)) Next(ls(pos), x);
        }
    }

    // Unlinks pos if it has at most one child and recycles its slot.
    // Returns false (tree untouched) when pos has two children.
    bool del(int pos) {
        if (ls(pos) && rs(pos)) return false;
        int child = ls(pos) ? ls(pos) : rs(pos);
        replace_in_parent(pos, child);
        tr[pos].init();
        free_slots.push_back(pos);
        return true;
    }

    // Adds to Rank the number of stored copies strictly less than val.
    void rank_of_val(int pos, int val) {
        if (!pos) return;
        if (tr[pos].val < val) {
            Rank += tr[ls(pos)].sz + tr[pos].cnt;
            rank_of_val(rs(pos), val);
        } else {
            rank_of_val(ls(pos), val);
        }
    }

    // Sets ans to the value of the remain-th smallest copy (1-based);
    // leaves ans unchanged when remain is out of range.
    void val_of_rank(int pos, int remain) {
        if (!pos) return;
        if (tr[ls(pos)].sz < remain) {
            if (tr[ls(pos)].sz + tr[pos].cnt >= remain) {
                ans = tr[pos].val;
                return;
            }
            remain -= tr[ls(pos)].sz + tr[pos].cnt;
            val_of_rank(rs(pos), remain);
        } else {
            val_of_rank(ls(pos), remain);
        }
    }

    // Removes one copy of the value held by node pos, then refreshes and
    // rebalances every ancestor of the structurally changed node.
    void erase(int pos) {
        int t = Fa(pos);
        if (tr[pos].cnt == 1) {
            if (!del(pos)) {
                // Two children: move the in-order predecessor (max of the left
                // subtree, which has no right child) into pos, then unlink it.
                int tmp = ls(pos);
                while (rs(tmp)) tmp = rs(tmp);
                t = Fa(tmp);
                tr[pos].val = tr[tmp].val;
                tr[pos].cnt = tr[tmp].cnt;
                del(tmp);
            }
        } else {
            tr[pos].cnt--;
            maintain1(pos);
        }
        while (t) {
            maintain(t);
            t = Fa(t);
        }
    }
};

AVL_tree solve;

// Reads n commands "x y":
//   1 insert y, 2 delete one y, 3 rank of y, 4 value of rank y,
//   5 predecessor of y, 6 successor of y.
int main() {
    int n, x, y;
    if (!(std::cin >> n)) return 0;
    solve.init();
    for (int i = 1; i <= n; i++) {
        if (!(std::cin >> x >> y)) break;
        if (x == 1) {
            if (solve.root == 0) solve.root = solve.creat(y, 0);
            else solve.insert(solve.root, y);
        } else if (x == 2) {
            int pos = solve.find(solve.root, y);
            if (pos == 0) std::printf("miss ");
            else solve.erase(pos);
        } else if (x == 3) {
            solve.Rank = 0;
            solve.rank_of_val(solve.root, y);
            std::printf("%d ", solve.Rank + 1);
        } else if (x == 4) {
            solve.ans = 0;
            solve.val_of_rank(solve.root, y);
            std::printf("%d ", solve.ans);
        } else if (x == 5) {
            solve.pre_pos = -1;
            solve.pre(solve.root, y);
            // Only index tr[] when a node was found (tr[-1] is out of bounds).
            if (solve.pre_pos == -1) std::printf("not found ");
            else std::printf("%d ", solve.tr[solve.pre_pos].val);
        } else if (x == 6) {
            solve.Next_pos = -1;
            solve.Next(solve.root, y);
            if (solve.Next_pos == -1) std::printf("not found ");
            else std::printf("%d ", solve.tr[solve.Next_pos].val);
        }
    }
    return 0;
}