zoukankan      html  css  js  c++  java
  • 二叉查找树,AVL,红黑树的Python实现

    简介:本文重点给出面试高频二叉树的实现

    二叉查找树,顾名思义,就是用于辅助我们进行查找的树状数据结构。

    在讲本文的主角之前,先讲一下其他与查询相关的数据结构。

    首先,无序表,查找的时间复杂度为O(n).

    有序表(预排序),查找(二分查找)的时间复杂度为O(logn),但是插入和删除的时间复杂度为O(n)

    那如何降低插入和删除的时间复杂度呢,我们本文的主角就登场了

    一、二叉查找树(二叉搜索树/二叉排序树,它的称呼比较多)

    定义:一棵二叉查找树是一棵二叉树,其中每个结点都包含一个键(以及相关联的值)且每个结点的键都大于其左子树中的任意结点的键而小于右子树的任意结点的键。

    它能够实现O(logn)的查找平均时间复杂度,且插入和删除的平均时间复杂度也为O(logn).

    先给出二叉搜索树结点的定义

    class BSTNode:
        """树的结点"""
        def __init__(self, val):
            self.val = val
            self.left = None
            self.right = None

    再给出二叉搜索树的实现

    class BST:
        """二叉搜索树"""
    
        def __init__(self):
            self.root = None
            
        def find(self, val):
            """查找方法,返回布尔值"""
            return self._find(self.root, val)
        
        def _find(self, node, val):
            if node is None:
                return False
            if node.val == val:
                return True
            if val<node.val:
                return self._find(node.left,val)
            else:  
                return self._find(node.right,val)
            
        def insert(self, val):
            """插入方法,当前二叉搜索树定义为不包含重复key的树,所以当插入重复key时会失败"""
            if self.find(val):
                return False
            if not self.root:
                self.root = node=BSTNode(val)
            else:
                self._insert(self.root, val)
            return True
        
        def _insert(self, node, val):
            if not node: 
                node=BSTNode(val)
            else:
                if val<node.val:
                    node.left=self._insert(node.left,val)
                elif val>node.val:
                    node.right=self._insert(node.right,val)
            return node
      
        def findmin(self):
            """查找key值最小的结点并返回该结点"""
            if not root:
                return None
            else:
                return self._findmin(self.root)
        
        def _findmin(self, node):
            if node.left:
                return self._findmin(node.left)
            else:
                return node
        
        def delete(self, val):
            """删除,返回布尔值"""
            if not self.find(val):
                return False
            else:
                self._delete(self.root, val)
                return True
        
        def _delete(self, node, val):
            """删除方法,最为复杂"""
            if not node: 
                return None
            if val<node.val:
                node.left = self._delete(node.left,val)
            elif val>node.val:
                node.right = self._delete(node.right,val)
            else:
                #当待删除结点有左右孩子时,选择左子树最大结点或者右子树最小节点用来替换该节点的值
                if node.left and node.right:
                    node.val=self._findmin(node.right).val
                    #替换完成后把对应的左子树最大节点或者右子树最小节点进行递归删除
                    node.right = self._delete(node.right,node.val)
                else:
                    #当待删除结点只有一个孩子时,直接用孩子替换该结点
                    node = node.left if not node.right else node.right
            return node

    问题:在极端情况下,比如二叉查找树的所有节点只有左子树的情况,所有操作的时间复杂度都会变为O(n),根本原因是二叉树的结构不平衡。、

    二、AVL树

    为了解决二叉查找树不平衡情况下糟糕的时间复杂度,我们在其基础上增加了一个约束:每个结点的左右子树的高度差不能大于1,具备该条件的二叉查找树即为AVL树

    AVL树保证了查找、插入、删除的最差时间复杂度均为O(logn)

    class AVLNode:
        """树的结点"""
        def __init__(self, val):
            self.val = val
            self.left = None
            self.right = None
            self.height = None
    class AVL:
        """二叉平衡树之AVL"""
        def __init__(self):
            self.root = None
        
        def _height(self, node):
            """返回当前结点的高度"""
            return -1 if not node.height else node.height
        
        def find(self, val):
            return self._find(self.root, val)
        
        def _find(self, node, val):
            if not node:
                return False
            elif val == node.val:
                return True
            elif val < node.val:
                return self._find(node.left, val)
            else:
                return self._find(node.right, val)
        
        def insert(self, val):
            if self.find(val):
                return False
            if not self.root:
                self.root = AVLNode(val)
            else:
                self._insert(self.root, val)
            return True
        
        def _insert(self, node, val):
            if not node:
                node = AVLNode(val)
            elif val < node.val:
                node.left = self._insert(node.left, val)
            elif val > node.val:
                node.right = self._insert(node.right, val)
            return self._balance(node)
        
        def _balance(self, node):
            """对结点进行高度平衡"""
            if not node:
                return node
            if self._height(node.left)-self._height(node.right)>1:
                if self._height(node.left.left)>self._height(node.left.right):
                    self._rotate_single_left(node)
                else:
                    self._rotate_double_left(node)
            elif self._height(node.right)-self._height(node.left)>1:
                if self._height(node.right.right)>self._height(node.right.left):
                    self._rotate_single_right(node)
                else:
                    self._rotate_double_right(node)
            node.height = max(height(node.left), height(node.right))+1
            return node
        
        def _rotate_single_left(self, node):
            new_node = node.left
            node.left = new_node.right
            new_node.right = node
            node.height = max(self._height(node.left), self._height(node.right))+1
            new_node.height = max(self._height(new_node.left), self._height(new_node.right))+1
            return new_node
        
        def _rotate_single_right(self, node):
            new_node = node.right
            node.right = new_node.left
            new_node.left = node
            node.height = max(self._height(node.left), self._height(node.right))+1
            new_node.height = max(self._height(new_node.left), self._height(new_node.right))+1
            return new_node
        
        def _rotate_double_left(self, node):
            node.left = self._rotate_single_right(node.left)
            return self._rotate_single_left(node)
        
        def _rotate_double_right(self, node):
            node.right = self._rotate_single_left(node.right)
            return self._rotate_single_right(node)
        
        def remove(self, val):
            if not self.find(val):
                return False
            else:
                self.remove(self, self.root, val)
                return True
        
        def findmin(self):
            return self._findmin(self.root)
        
        def _findmin(self, node):
            if node.left:
                return self._findmin(node.left)
            else:
                return node
            
        def _remove(self, node, val):
            if not node:
                return None
            if val < node.val:
                self._remove(node.left, val)
            elif val > node.val:
                self._remove(node.right, val)
            else:
                if node.left and node.right:
                    node.val = self._findmin(node.right).val
                    node.right = self._remove(node.right, node.val)
                else:
                    node = node.left if not node.right else node.right
            return self._balance(node)

    我们仔细分析一下,查询所用时间复杂度都耗费在递归过程上,为O(logn)

    插入所用时间复杂度耗费在递归(logn)和旋转,所以插入时间复杂度O(logn)

    删除所用时间复杂度耗费在递归(logn)和旋转(至多logn次,当删除结点导致根节点高度不平衡时,需要从当前结点到根节点的路径上递归进行平衡),总的删除时间复杂度为O(logn)

    问题:我们可以看出,在所有过程中,删除过程中的旋转是相对最耗费时间的,那我们有没有办法减少总的旋转次数呢?

    三、红黑树

    针对上述问题,我们思考,之所以旋转次数多是因为AVL的平衡条件过于苛刻(左右子树之差小于等于1),我们能不能通过放宽这个限制来取得更加优秀的删除效率呢,当然可以,那就是接下来要讲的红黑树。

    红黑树通过着色的方式实现2-3树(关于2-3树可以看书,不具体介绍了),红黑树规定红色链接必为左连接,红色链接连接的两个点构成一个3-结点。

    class RBTNode:
        def __init__(self, val, color, n):
            self.val = val
            self.color = color
            self.left = None
            self.right = None
    class RBT:
        def __init__(self):
            self.root = None
        
        def _left_rotate(self, node):
            """左旋,把红色右链接转化为左链接"""
            new_node = node.right
            node.right = new_node.left
            new_node.left = node
            node.color = 'RED'
            return new_node
        
        def _right_rotate(self, node):
            """右旋,把红色左链接转化为右链接"""
            new_node = node.left
            node.left = new_node.right
            new_node.right = node
            new_node.color = 'RED'
            
        def _flipcolors(self, node):
            """反转当前结点及其子节点的颜色"""
            if node.color = 'BLACK':
                node.color = 'RED'
                node.left.color = 'BLACK'
                node.right.color = 'BLACK'
            else:
                node.color = 'BLACK'
                node.left.color = 'RED'
                node.right.color = 'RED'
            
        def find(self, val):
            return self._find(self.root, val)
        
        def _find(self, node, val):
            if not node:
                return False
            if node.val == val:
                return True
            elif val < nodel.val:
                return _find(node.left, val)
            else:
                return _find(node.right, val)
            
        def _isred(self, node):
            return True if node.color == 'RED' else False
        
        def insert(self, val):
            if self.find(val):
                return False
            else:
                self._insert(self.root, val)
                return True
            
        def _insert(self, node, val):
            if not node:
                return RBTNode(val, 'RED')
            if val < node.val:
                node.left = self._insert(node.left, val)
            elif val > node.val:
                node.right = self._insert(node.right, val)
            self._balance(node)
            
            return node
        
        def _balance(self, node):
            if self._isred(node.right) and not self._isred(node.left):
                node = self._left_rotate(node)
            if self._isred(node.left) and self._isred(node.left.left):
                node = self._right_rotate(node)
            if self._isred(node.left) and self._isred(node.right):
                self._flipcolors(node)
            return node
        
        def delete_min(self):
            if not self._isred(self.root.left) and not self._isred(self.root.right):
                self.root.color = 'RED'
            self.root = _delete_min(self.root)
            if self.root:
                self.root.color = 'BLACK'
        
        def _delete_min(self, node):
            if not node.left:
                return None
            if not self._isred and not self._isred(node.left.left):
                node = self._move_red_left(node)
            node.left = self._delete_min(node.left)
            return self._balance(node)
        
        def delete_max(self):
            if not self._isred(self.root.left) and not self._isred(self.root.right):
                self.root.color = 'RED'
            self.root = self._delete_max(self.root)
            if not self.root:
                self.color = 'BLACK'
        
        def _delete_max(self, node):
            if self._isred(node.left):
                node = self._right_rotate(node)
            if not node.right:
                return None
            if not self._isred(node.right) and not self._isred(node.right.left):
                node = self._move_red_right(node)
            node.right = self._delete_max(node.right)
            return self._balance(node)
        
        def _move_red_right(self, node):
            self._flipcolors(node)
            if not self._isred(node.left.left):
                node = self._right_rotate(node)
            return node
        
        def delete(self, val):
            if not self._isred(self.root.left) and not self._isred(self.root.right):
                self.root.color = 'RED'
            self.root = self._delete(self.root, val)
            if not self.root:
                self.color = 'BLACK'
                
        def _delete(self, node, val):
            if val < node.val:
                if not self._isred and not self._isred(node.left.left):
                    node = self._move_red_left(node)
                node.left = self._delete(node.left, val)
            else:
                if self._isred(node.left):
                    node = self._right_rotate(node)
                if val == node.val and not node.right:
                    return None
                if not self._isred(node.right) and not self._isred(node.right.left):
                    node = self._move_red_right(node)
                if val == node.val:
                    node.val = self._findmin(node.right).val
                    node.right = self._delete_min(node.right)
                else:
                    node.right = self._delete(node.right, val)
            return self._balance(node)
        
        def _findmin(self, node):
            if self.left:
                return self._findmin(node.left)
            return node

    代码自己写的,如有错误请指正。

    参考:

    Algorithm(4th version)

    数据结构与算法——Java语言描述

  • 相关阅读:
    发现一个奇怪的问题: 不能把文件取名为 con
    博客园新购服务器硬件配置
    [重要新功能]团队Blog
    [庆祝]博客园已迁至新服务器
    [公告]博客园论坛开放注册
    博客园出现了奇怪的cookie问题
    [公告]博客园聊天室试运行
    [重发]请为你喜欢的博客园杂志的名字投上一票
    [征询意见]准备采用“创作共用”协议保护大家的原创作品
    博客园期刊制作小组Blog开通
  • 原文地址:https://www.cnblogs.com/tianyadream/p/12543823.html
Copyright © 2011-2022 走看看