#include<iostream>
#include<stack>
#include<utility>
/// A single AVL tree node.
///
/// `height` is the height of the subtree rooted at this node: a leaf has
/// height 1, and an empty subtree (nullptr) is treated as height 0 by the tree.
template <class T>
class node {
public:
    T data;
    int height;
    node<T>* left;
    node<T>* right;

    /// Creates a leaf node with a value-initialized payload.
    node() : data(), height(1), left(nullptr), right(nullptr) {}

    /// Creates a leaf node holding `x`. The payload is constructed directly
    /// from `x`, so T does not need to be default-constructible.
    node(T x) : data(std::move(x)), height(1), left(nullptr), right(nullptr) {}
};

/// Self-balancing binary search tree (AVL tree) of unique values.
///
/// Ordering uses only `operator<`, so T only has to provide `<`; two values
/// a and b are considered equal when neither `a < b` nor `b < a`. Inserting
/// a value that is already present does nothing.
///
/// The tree owns its nodes: the destructor frees them, copies are deep, and
/// moves transfer ownership (the moved-from tree becomes empty).
template <class T>
class AVLtree {
    node<T>* root;
    int quantity;  // number of elements currently stored

    //-------------------------------------------------
    // Internal helpers; not part of the public interface.
    //-------------------------------------------------
    int max(int a, int b) { return a > b ? a : b; }

    /// Height of a subtree; an empty subtree (nullptr) has height 0.
    int height(node<T>* x) { return x == nullptr ? 0 : x->height; }

    /// Height of x recomputed from its children's heights.
    int update_h(node<T>* x) {
        return max(height(x->left), height(x->right)) + 1;
    }

    /// Right rotation around x (repairs a left-left imbalance).
    /// Returns the new root of the subtree.
    node<T>* ll_rotate(node<T>* x) {
        node<T>* y = x->left;
        x->left = y->right;
        y->right = x;
        x->height = update_h(x);  // x is now y's child, so refresh it first
        y->height = update_h(y);
        return y;
    }

    /// Left rotation around x (repairs a right-right imbalance).
    /// Returns the new root of the subtree.
    node<T>* rr_rotate(node<T>* x) {
        node<T>* y = x->right;
        x->right = y->left;
        y->left = x;
        x->height = update_h(x);  // x is now y's child, so refresh it first
        y->height = update_h(y);
        return y;
    }

    /// Double rotation for a left-right imbalance.
    node<T>* lr_rotate(node<T>* x) {
        x->left = rr_rotate(x->left);
        return ll_rotate(x);
    }

    /// Double rotation for a right-left imbalance.
    node<T>* rl_rotate(node<T>* x) {
        x->right = ll_rotate(x->right);
        return rr_rotate(x);
    }

    /// Height of the left subtree minus height of the right subtree (0 for nullptr).
    int getBalance(node<T>* x) {
        if (x == nullptr) return 0;
        return height(x->left) - height(x->right);
    }

    /// Inserts x into the subtree rooted at p and returns the (possibly new)
    /// subtree root. Only `<` is used for comparisons.
    node<T>* insert(node<T>* p, const T& x) {
        if (p == nullptr) {
            node<T>* fresh = new node<T>(x);
            ++quantity;  // count only after the allocation succeeded
            return fresh;
        }
        if (x < p->data)
            p->left = insert(p->left, x);
        else if (p->data < x)
            p->right = insert(p->right, x);
        else
            return p;  // duplicate: tree is unchanged

        p->height = update_h(p);
        const int b = getBalance(p);
        // An imbalance can only come from the child x was just inserted into,
        // so comparing x with THAT child's key selects the rotation case.
        if (1 < b && x < p->left->data) return ll_rotate(p);    // left-left
        if (b < -1 && p->right->data < x) return rr_rotate(p);  // right-right
        if (1 < b && p->left->data < x) return lr_rotate(p);    // left-right
        if (b < -1 && x < p->right->data) return rl_rotate(p);  // right-left
        return p;
    }

    /// Returns a pointer to the stored element equal to x in the subtree
    /// rooted at p, or nullptr if there is none.
    T* find(const T& x, node<T>* p) {
        while (p != nullptr) {
            if (x < p->data)
                p = p->left;
            else if (p->data < x)
                p = p->right;
            else
                return &p->data;
        }
        return nullptr;
    }

    /// Leftmost (smallest) node of a non-empty subtree.
    node<T>* minValueNode(node<T>* p) {
        node<T>* current = p;
        while (current->left != nullptr)
            current = current->left;
        return current;
    }

    /// Removes x from the subtree rooted at p (if present) and returns the
    /// (possibly new) subtree root. `quantity` is decremented only when a
    /// node is actually freed.
    node<T>* erase(const T& x, node<T>* p) {
        if (p == nullptr) return nullptr;  // x is not in this subtree
        if (p->data < x) {
            p->right = erase(x, p->right);
        } else if (x < p->data) {
            p->left = erase(x, p->left);
        } else if (p->left == nullptr || p->right == nullptr) {
            // Zero or one child: splice the child (possibly nullptr) into p's place.
            node<T>* child = p->left ? p->left : p->right;
            delete p;
            --quantity;
            p = child;
            if (p == nullptr) return nullptr;
        } else {
            // Two children: take the in-order successor's value, then delete
            // the successor from the right subtree. Pass p->data rather than
            // successor->data because the successor node is freed during the call.
            node<T>* successor = minValueNode(p->right);
            p->data = successor->data;
            p->right = erase(p->data, p->right);
        }

        p->height = update_h(p);
        const int b = getBalance(p);
        // After a deletion the child's balance (not a key) selects the case.
        if (1 < b && -1 < getBalance(p->left)) return ll_rotate(p);
        if (b < -1 && getBalance(p->right) < 1) return rr_rotate(p);
        if (1 < b && getBalance(p->left) < 0) return lr_rotate(p);
        if (b < -1 && 0 < getBalance(p->right)) return rl_rotate(p);
        return p;
    }

    /// Frees every node of the subtree rooted at p. Recursion depth is bounded
    /// by the tree height, which AVL balancing keeps logarithmic.
    static void destroy(node<T>* p) {
        if (p == nullptr) return;
        destroy(p->left);
        destroy(p->right);
        delete p;
    }

    /// Deep-copies the subtree rooted at src (including stored heights).
    /// If any allocation or copy throws, everything cloned so far is freed.
    static node<T>* clone(const node<T>* src) {
        if (src == nullptr) return nullptr;
        node<T>* copy = new node<T>(src->data);
        copy->height = src->height;
        try {
            copy->left = clone(src->left);
            copy->right = clone(src->right);
        } catch (...) {
            destroy(copy);  // children not yet assigned are still nullptr
            throw;
        }
        return copy;
    }

public:
    /// Forward iterator visiting elements in ascending order (in-order walk).
    ///
    /// It keeps an explicit stack of ancestors still to be visited. The end
    /// iterator is the one whose current node is nullptr. Dereferencing gives
    /// a mutable reference; changing a value's ordering through it breaks the
    /// tree invariants.
    class iterator {
        node<T>* p;                      // current node; nullptr means end()
        std::stack<node<T>*> inorder;    // pending ancestors, smallest on top

        /// Pushes x and its whole chain of left descendants.
        void push_left(node<T>* x) {
            while (x != nullptr) {
                inorder.push(x);
                x = x->left;
            }
        }

        /// Moves to the next node in order (or to end when none remain).
        void advance() {
            if (inorder.empty()) {
                p = nullptr;
                return;
            }
            p = inorder.top();
            inorder.pop();
            push_left(p->right);  // the right subtree comes right after p
        }

    public:
        /// Positions the iterator on the smallest element of the subtree
        /// rooted at x (end() when x is nullptr).
        explicit iterator(node<T>* x) : p(nullptr) {
            push_left(x);
            advance();
        }
        bool operator==(const iterator& x) const { return p == x.p; }
        bool operator!=(const iterator& x) const { return p != x.p; }
        T& operator*() const { return p->data; }
        node<T>* operator->() const { return p; }
        iterator& operator++() {
            advance();
            return *this;
        }
        /// Postfix increment: advances, returning the iterator's previous state.
        iterator operator++(int) {
            iterator old = *this;
            advance();
            return old;
        }
    };

    AVLtree() : root(nullptr), quantity(0) {}
    ~AVLtree() { destroy(root); }
    AVLtree(const AVLtree& other) : root(clone(other.root)), quantity(other.quantity) {}
    AVLtree(AVLtree&& other) noexcept : root(other.root), quantity(other.quantity) {
        other.root = nullptr;
        other.quantity = 0;
    }
    /// Copy- and move-assignment (copy-and-swap): `other` is already a copy
    /// or a moved-into temporary, so swapping cannot fail.
    AVLtree& operator=(AVLtree other) noexcept {
        std::swap(root, other.root);
        std::swap(quantity, other.quantity);
        return *this;
    }

    /// Inserts x; does nothing if an equal element is already present.
    void insert(T x) { root = insert(root, x); }
    /// Removes the element equal to x, if any.
    void erase(T x) { root = erase(x, root); }
    /// Returns a pointer to the element equal to x, or nullptr if absent.
    /// Do not modify the element in a way that changes its ordering.
    T* find(T x) { return find(x, root); }
    /// Number of stored elements.
    int size() const { return quantity; }
    /// True when the tree holds no elements.
    bool empty() const { return quantity == 0; }
    iterator begin() { return iterator(root); }
    iterator end() { return iterator(nullptr); }
};