1 ////二叉查找树,为了实现方便,给每个节点添加了一个指向父节点的指针 2 #include<iostream> 3 #include<vector> 4 #include<ctime> 5 #include<cstdlib> 6 7 using namespace std; 8 9 template<class T> 10 class BinarySearchTree 11 { 12 private: 13 struct Node 14 { 15 T data; 16 int deep; 17 Node *left; 18 Node *right; 19 Node *prev; 20 Node(T val,int deep) 21 { 22 data = val; 23 deep = 0; 24 left = NULL; 25 right = NULL; 26 prev = NULL; 27 } 28 29 private: 30 Node() 31 { 32 } 33 }; 34 Node *root; 35 int size; 36 37 public: 38 BinarySearchTree() 39 { 40 root = NULL; 41 size = 0; 42 } 43 ~BinarySearchTree() 44 { 45 clear(root); 46 root = NULL; 47 size = 0; 48 } 49 T min(Node *node) const 50 { 51 if(node->left == NULL) 52 return node->data; 53 else 54 return min(node->left); 55 } 56 T max(Node *node) const 57 { 58 if(node->right == NULL) 59 return node->data; 60 else 61 return max(node->right); 62 } 63 64 Node *insert(Node *& node,T val) 65 { 66 if(size == 0 && node == NULL) 67 { 68 root = new Node(val,0); 69 root->prev = NULL; 70 size++; 71 return root; 72 } 73 if(size != 0 && node == NULL) 74 { 75 cout<<"ERROR "; 76 return NULL; 77 } 78 if(val > node->data) 79 { 80 if(node->right != NULL) 81 return insert(node->right,val); 82 else 83 { 84 Node *tmp = new Node(val,node->deep+1); 85 tmp->prev = node; 86 node->right = tmp; 87 size++; 88 return tmp; 89 } 90 } 91 else if(val < node->data) 92 { 93 if(node->left != NULL) 94 return insert(node->left,val); 95 else 96 { 97 Node *tmp = new Node(val,node->deep+1); 98 tmp->prev = node; 99 node->left = tmp; 100 size ++; 101 return tmp; 102 } 103 } 104 else if(val == node->data) 105 { 106 } 107 } 108 109 bool contain(Node *node,T val) const 110 { 111 if(node == NULL) 112 return false; 113 114 if(val > node->data) 115 return contain(node->right,val); 116 else if(val < node->data) 117 return contain(node->left,val); 118 else 119 return true; 120 } 121 void removeNode(Node *node) 122 { 123 if(node->left == NULL && node->right == NULL) 124 { 125 if(node->prev->left == node) 126 node->prev->left = NULL; 127 else 128 node->prev->right = NULL; 129 130 delete node; 131 size--; 132 } 133 else if(node->left == NULL) 134 { 135 node->right->prev = node->prev; 136 if(node->prev->left == node) 137 node->prev->left = node->right; 138 else 139 node->prev->right = node->right; 140 141 decDeep(node->right); 142 delete node; 143 size--; 144 } 145 else if(node->right == NULL) 146 { 147 node->left->prev = node->prev; 148 if(node->prev->left == node) 149 node->prev->left = node->left; 150 else 151 node->prev->right = node->left; 152 153 decDeep(node->left); 154 delete node; 155 size--; 156 } 157 else 158 { 159 Node *p = node->right; 160 while(p->left != NULL) 161 { 162 p=p->left; 163 } 164 node->data = p->data; 165 if(p->right != NULL) 166 { 167 p->prev->left = p->right; 168 p->right->prev = p->prev; 169 decDeep(p->right); 170 delete p; 171 size--; 172 } 173 else 174 { 175 p->prev->left = NULL; 176 delete p; 177 size--; 178 } 179 } 180 } 181 void decDeep(Node *node) 182 { 183 node->deep--; 184 if(node->left != NULL) 185 decDeep(node->left); 186 if(node->right != NULL) 187 decDeep(node->right); 188 } 189 void remove(T val) 190 { 191 Node * p=root; 192 while(1) 193 { 194 if(val > p->data) 195 p = p->right; 196 else if(val < p->data) 197 p = p->left; 198 else if(val == p->data) 199 { 200 201 removeNode(p); 202 return; 203 } 204 } 205 } 206 void clear(Node*node) 207 { 208 if(node->left != NULL) 209 clear(node->left); 210 if(node->right != NULL) 211 clear(node->right); 212 213 delete node; 214 node = NULL; 215 } 216 void print(Node *node) 217 { 218 if(node == NULL) 219 return; 220 cout<<node->data<< " "; 221 if(node->left != NULL) 222 print(node->left); 223 if(node->right != NULL) 224 print(node->right); 225 } 226 void insert(T val) 227 { 228 insert(root,val); 229 } 230 void print() 231 { 232 print(root); 233 cout<<" "; 234 } 235 }; 236 237 int main() 238 { 239 BinarySearchTree<int> tree; 240 tree.insert(10); 241 tree.insert(1); 242 tree.insert(11); 243 tree.insert(9); 244 tree.insert(8); 245 tree.print(); 246 cout<<" "; 247 tree.remove(9); 248 tree.print(); 249 250 return 0; 251 }