#include <iostream>
#include <vector>
#include <string>
#include <queue>
#include <cassert>
using namespace std;
template<typename Key, typename Value>
class BST {
public :
BST() {
root = NULL;
count = 0;
}
~BST() {
// TODO : ~BST()
destroy(root);
}
int size() {
return count;
}
bool isEmpty() {
return count == 0;
}
// 插入结点
void insert(Key key, Value value) {
root = insert(root, key, value);
}
// 是否包含该键值的结点
bool contain(Key key) {
return contain(root, key);
}
// 返回值为指针类型: 若未查找到, 返回空
// 若查找到, 返回指向value地址的指针
Value * search(Key key) {
return search(root, key);
}
//前序遍历
void preOrder() {
preOrder(root);
}
// 中序遍历
void inOrder() {
inOrder(root);
}
// 后序遍历
void postOrder() {
postOrder(root);
}
// 层序遍历
void levelOrder() {
queue<Node*> q;
q.push(root);
while(!q.empty()) {
Node * node = q.front();
q.pop();
cout << node->key << endl;
if(node->left)
q.push(node->left);
if(node->right)
q.push(node->right);
}
}
// 寻找最小的键值
Key minimum() {
assert(count != 0);
Node * minNode = minimum(root);
return minNode->key;
}
// 寻找最大的键值
Key maximum() {
assert(count != 0);
Node * maxNode = maximum(root);
return maxNode->key;
}
// 从二叉树中删除最小值所在节点
void removeMin() {
if(root)
root = removeMin(root);
}
// 从二叉树中删除最大值所在的结点
void removeMax() {
if(root)
root = removeMax(root);
}
// 从二叉树中删除键值为key的结点
void remove(Key key) {
root = remove(root, key);
}
private :
struct Node {
Key key;
Value value;
Node * left;
Node * right;
Node(Key key, Value value) {
this->key = key;
this->value = value;
this->left = this->right = NULL;
}
Node(Node *node) {
this->key = node->key;
this->value = node->value;
this->left = node->left;
this->right = node->right;
}
};
Node * root;
int count; // 二分搜索树的节点个数
Node * insert(Node * node, Key key, Value value) {
if(node == NULL) {
count ++;
return new Node(key, value);
}
if(key == node->key)
node->value = value;
else if(key < node->key)
node->left = insert(node->left, key, value);
else
node->right = insert(node->right, key, value);
return node;
}
bool contain(Node * node, Key key) {
if(node == NULL)
return false;
if(key == node->key)
return true;
else if(key < node->key)
return contain(node->left, key);
else
return contain(node->right, key);
}
// 在以node为根的二叉搜索树中查找key所对应的value
Value * search(Node * node, Key key) {
if(node == NULL)
return NULL;
if(key == node->key)
return &(node->value);
else if(key < node->key)
return search(node->left, key);
else
return search(node->right, key);
}
void preOrder(Node * node) {
if(node != NULL) {
cout << node->key << endl;
preOrder(node->left);
preOrder(node->right);
}
return ;
}
void inOrder(Node * node) {
if(node != NULL) {
inOrder(node->left);
cout << node->key << endl;
inOrder(node->right);
}
return ;
}
void postOrder(Node * node) {
if(node != NULL) {
postOrder(node->left);
postOrder(node->right);
cout << node->key << endl;
}
return ;
}
void destroy(Node * node) {
// 先删除左子树, 再删除右子树, 最后删除根节点
if(node != NULL) {
destroy(node->left);
destroy(node->right);
delete node;
count --;
}
}
// 在以node为根的二叉搜索树中, 返回最小键值的结点
Node * minimum(Node * node) {
if(node->left == NULL)
return node;
return minumum(node->left);
}
// 在以node为根的二叉搜索树中, 返回最大键值的结点
Node maximum(Node * node) {
if(node->right == NULL)
return node;
return maximum(node->right);
}
// 删除以node为根的二分搜索树中的最小节点
// 返回删除结点后的新的二分搜索树的根
Node * removeMin(Node * node) {
if(node->left == NULL) {
Node * rightNode = node->right;
delete node;
count --;
return rightNode;
}
node->left = removeMin(node->left);
return node;
}
// 删除以node为根的二分搜索树中的最大节点
// 返回删除结点后的新的二分搜索树的根
Node * removeMax(Node * node) {
if(node->right == NULL) {
Node * leftNode = node->left;
delete node;
count --;
return leftNode;
}
node->right = removeMax(node->right);
return node;
}
// 删除以node为根的二分搜索树中键值为key的结点
// 返回删除结点后的二分搜索树的根
Node * remove(Node * node, Key key) {
if(node == NULL)
return NULL;
if(key < node->key) {
node->left = remove(node->left, key);
return node;
}
else if(key > node->key) {
node->right = remove(node->right, key);
return node;
}
else { // key == node->key
if(node->left == NULL) {
Node * rightNode = node->right;
delete node;
count --;
return rightNode;
}
if(node->right == NULL) {
Node * leftNode = node->left;
delete node;
count --;
return leftNode;
}
// node->left != NULL && node->right != NULL
Node * successor = new Node(minimum(node->right));
count ++;
successor->right = removeMin(node->right);
successor->left = node->left;
delete node;
count --;
return successor;
}
}
};