zoukankan      html  css  js  c++  java
  • (三)用go实现平衡二叉树

    本篇，我们用 Go 简单地实现平衡二叉查找树（AVL 树）。具体原理可参考博客：AVL树(一)之 图文解析 和 C语言的实现

    1.节点定义

    // AVLNode is one node of the AVL tree. height is the cached height of the
    // subtree rooted here (a freshly inserted leaf gets 1; an empty subtree
    // counts as 0).
    type AVLNode struct {
        data   int      // the stored key
        height int      // cached subtree height
        left   *AVLNode // subtree with smaller keys
        right  *AVLNode // subtree with larger keys
    }
    

    2.树的遍历

    // PreTraverse prints the tree in pre-order (node, left, right). Each node
    // is written as "data:height" followed by a space; a nil tree prints nothing.
    func PreTraverse(p *AVLNode) {
        if p != nil {
            fmt.Printf("%d:%d ", p.data, p.height)
            PreTraverse(p.left)
            PreTraverse(p.right)
        }
    }
    
    // InTraverse prints the keys in in-order (left, node, right), each followed
    // by a space. Because Add places smaller keys on the left, this yields the
    // keys in ascending order. A nil tree prints nothing.
    func InTraverse(p *AVLNode) {
        if p != nil {
            InTraverse(p.left)
            fmt.Printf("%d ", p.data)
            InTraverse(p.right)
        }
    }
    
    // PostTraverse prints the keys in post-order (left, right, node), each
    // followed by a space. A nil tree prints nothing.
    func PostTraverse(p *AVLNode) {
        if p != nil {
            PostTraverse(p.left)
            PostTraverse(p.right)
            fmt.Printf("%d ", p.data)
        }
    }
    

    3.树的旋转

    // ll_rotate repairs a left-left imbalance with a single right rotation
    // around k2. k2's left child becomes the new subtree root, and k2 becomes
    // that child's right child, adopting the child's former right subtree.
    // k2's height is recomputed first because the new root's height depends on it.
    func ll_rotate(k2 *AVLNode) *AVLNode {
        newRoot := k2.left
        k2.left = newRoot.right
        newRoot.right = k2

        k2.height = 1 + max(height(k2.left), height(k2.right))
        newRoot.height = 1 + max(height(newRoot.left), k2.height)
        return newRoot
    }
    
    // rr_rotate repairs a right-right imbalance with a single left rotation
    // around k1. k1's right child becomes the new subtree root, and k1 becomes
    // that child's left child, adopting the child's former left subtree.
    // k1's height is recomputed first because the new root's height depends on it.
    func rr_rotate(k1 *AVLNode) *AVLNode {
        newRoot := k1.right
        k1.right = newRoot.left
        newRoot.left = k1

        k1.height = 1 + max(height(k1.left), height(k1.right))
        newRoot.height = 1 + max(height(newRoot.right), k1.height)
        return newRoot
    }
    
    // lr_rotate repairs a left-right imbalance in two steps. First, k3's left
    // child is rotated left (rr_rotate), which turns the shape into left-left.
    // Then k3 itself is rotated right (ll_rotate). It returns the new subtree root.
    func lr_rotate(k3 *AVLNode) *AVLNode {
        straightened := rr_rotate(k3.left)
        k3.left = straightened
        return ll_rotate(k3)
    }
    
    // rl_rotate repairs a right-left imbalance in two steps. First, k1's right
    // child is rotated right (ll_rotate), which turns the shape into
    // right-right. Then k1 itself is rotated left (rr_rotate). It returns the
    // new subtree root.
    func rl_rotate(k1 *AVLNode) *AVLNode {
        straightened := ll_rotate(k1.right)
        k1.right = straightened
        return rr_rotate(k1)
    }
    

    4.插入节点

    // 插入节点
    func Add(p *AVLNode, data int) *AVLNode {
        if p == nil {
            p = new(AVLNode)
            p.data = data
            p.height = 1
            return p
        }
    
        if data < p.data {
            p.left = Add(p.left, data)
            if height(p.left) - height(p.right) == 2 {
                if data > p.left.data {
                    fmt.Println("lr")
                    p = lr_rotate(p)
                } else {
                    fmt.Println("ll")
                    p = ll_rotate(p)
                }
            }
        } else if data > p.data {
            p.right = Add(p.right, data)
            if height(p.right) - height(p.left) == 2{
                if data > p.right.data {
                    fmt.Println("rr")
                    p = rr_rotate(p)
                } else {
                    fmt.Println("rl")
                    p = rl_rotate(p)
                }
            }
        } else {
            fmt.Println("Add fail: not allowed same data!")
        }
    
        p.height = max(height(p.left), height(p.right)) + 1
        fmt.Printf("节点:%d, 高度:%d
    ", p.data, p.height)
    
        return p
    }
    

    5.查询节点

    // Find returns the node whose key equals data in the binary search tree
    // rooted at p. It returns nil when the key is absent or the tree is empty
    // (p == nil), where the recursive version dereferenced p and panicked.
    func Find(p *AVLNode, data int) *AVLNode {
        for p != nil && p.data != data {
            if data < p.data {
                p = p.left
            } else {
                p = p.right
            }
        }
        return p
    }
    
    // maxNode returns the right-most node of the subtree rooted at p, which
    // holds its largest key. It returns nil when p is nil.
    func maxNode(p *AVLNode) *AVLNode {
        for p != nil && p.right != nil {
            p = p.right
        }
        return p
    }
    
    // minNode returns the left-most node of the subtree rooted at p, which
    // holds its smallest key. It returns nil when p is nil.
    func minNode(p *AVLNode) *AVLNode {
        for p != nil && p.left != nil {
            p = p.left
        }
        return p
    }
    

    6.删除节点

    // Delete removes the node holding data from the AVL tree rooted at p and
    // returns the new root, which may differ from p after rebalancing. Callers
    // should write root = Delete(root, x).
    //
    // An empty tree or an absent key leaves the tree untouched, and p is
    // returned as-is. Returning nil in that case would make
    // root = Delete(root, x) throw away the whole tree.
    func Delete(p *AVLNode, data int) *AVLNode {
        if p == nil {
            return nil
        }
        node := Find(p, data)
        if node == nil {
            return p
        }
        // delete is this file's tree helper, which shadows Go's builtin delete.
        return delete(p, node)
    }
    
    // delete removes node from the AVL subtree rooted at p, rebalancing on the
    // way back up, and returns the new subtree root. It descends by comparing
    // node.data with the keys on the path. node must be present in the
    // subtree, since p is dereferenced without a nil check; Delete ensures this
    // via Find.
    func delete(p, node *AVLNode) *AVLNode {
        switch {
        case node.data < p.data:
            p.left = delete(p.left, node)
            // The left side shrank, so the right side may now be 2 taller.
            if height(p.right)-height(p.left) == 2 {
                r := p.right
                // Use a double rotation only when the right child's inner (left)
                // subtree is strictly taller. When both of its subtrees have equal
                // height, a single rotation is required; RL would leave r unbalanced.
                if height(r.left) > height(r.right) {
                    p = rl_rotate(p)
                } else {
                    p = rr_rotate(p)
                }
            }
        case node.data > p.data:
            p.right = delete(p.right, node)
            // Mirror case: the left side may now be 2 taller.
            if height(p.left)-height(p.right) == 2 {
                l := p.left
                // Double rotation only when the inner (right) subtree is strictly taller.
                if height(l.right) > height(l.left) {
                    p = lr_rotate(p)
                } else {
                    p = ll_rotate(p)
                }
            }
        default:
            if p.left != nil && p.right != nil {
                // Two children: copy in the neighbouring key from the taller side
                // and delete that node instead. Only the taller side can shrink,
                // so p stays within balance without a rotation here.
                if height(p.left) > height(p.right) {
                    pred := maxNode(p.left)
                    p.data = pred.data
                    p.left = delete(p.left, pred)
                } else {
                    succ := minNode(p.right)
                    p.data = succ.data
                    p.right = delete(p.right, succ)
                }
            } else if p.left != nil {
                p = p.left
            } else {
                p = p.right
            }
        }

        if p != nil {
            p.height = max(height(p.left), height(p.right)) + 1
        }
        return p
    }
    

    7.完整代码

    package main
    
    import (
        "fmt"
    )
    
    // AVLNode is one node of the AVL tree. height is the cached height of the
    // subtree rooted here (a freshly inserted leaf gets 1; height(nil) is 0).
    type AVLNode struct {
        data   int      // the stored key
        height int      // cached subtree height
        left   *AVLNode // subtree with smaller keys
        right  *AVLNode // subtree with larger keys
    }
    
    // max returns the larger of a and b.
    func max(a, b int) int {
        if b >= a {
            return b
        }
        return a
    }
    
    // height returns the cached height of the subtree rooted at p, counting
    // an empty (nil) subtree as 0.
    func height(p *AVLNode) int {
        if p == nil {
            return 0
        }
        return p.height
    }
    
    // PreTraverse prints the tree in pre-order (node, left, right). Each node
    // is written as "data:height" followed by a space; a nil tree prints nothing.
    func PreTraverse(p *AVLNode) {
        if p != nil {
            fmt.Printf("%d:%d ", p.data, p.height)
            PreTraverse(p.left)
            PreTraverse(p.right)
        }
    }
    
    // InTraverse prints the keys in in-order (left, node, right), each followed
    // by a space. Because Add places smaller keys on the left, this yields the
    // keys in ascending order. A nil tree prints nothing.
    func InTraverse(p *AVLNode) {
        if p != nil {
            InTraverse(p.left)
            fmt.Printf("%d ", p.data)
            InTraverse(p.right)
        }
    }
    
    // PostTraverse prints the keys in post-order (left, right, node), each
    // followed by a space. A nil tree prints nothing.
    func PostTraverse(p *AVLNode) {
        if p != nil {
            PostTraverse(p.left)
            PostTraverse(p.right)
            fmt.Printf("%d ", p.data)
        }
    }
    
    
    // ll_rotate repairs a left-left imbalance with a single right rotation
    // around k2. k2's left child becomes the new subtree root, and k2 becomes
    // that child's right child, adopting the child's former right subtree.
    // k2's height is recomputed first because the new root's height depends on it.
    func ll_rotate(k2 *AVLNode) *AVLNode {
        newRoot := k2.left
        k2.left = newRoot.right
        newRoot.right = k2

        k2.height = 1 + max(height(k2.left), height(k2.right))
        newRoot.height = 1 + max(height(newRoot.left), k2.height)
        return newRoot
    }
    
    // rr_rotate repairs a right-right imbalance with a single left rotation
    // around k1. k1's right child becomes the new subtree root, and k1 becomes
    // that child's left child, adopting the child's former left subtree.
    // k1's height is recomputed first because the new root's height depends on it.
    func rr_rotate(k1 *AVLNode) *AVLNode {
        newRoot := k1.right
        k1.right = newRoot.left
        newRoot.left = k1

        k1.height = 1 + max(height(k1.left), height(k1.right))
        newRoot.height = 1 + max(height(newRoot.right), k1.height)
        return newRoot
    }
    
    // lr_rotate repairs a left-right imbalance in two steps. First, k3's left
    // child is rotated left (rr_rotate), which turns the shape into left-left.
    // Then k3 itself is rotated right (ll_rotate). It returns the new subtree root.
    func lr_rotate(k3 *AVLNode) *AVLNode {
        straightened := rr_rotate(k3.left)
        k3.left = straightened
        return ll_rotate(k3)
    }
    
    // rl_rotate repairs a right-left imbalance in two steps. First, k1's right
    // child is rotated right (ll_rotate), which turns the shape into
    // right-right. Then k1 itself is rotated left (rr_rotate). It returns the
    // new subtree root.
    func rl_rotate(k1 *AVLNode) *AVLNode {
        straightened := ll_rotate(k1.right)
        k1.right = straightened
        return rr_rotate(k1)
    }
    
    // 插入节点
    func Add(p *AVLNode, data int) *AVLNode {
        if p == nil {
            p = new(AVLNode)
            p.data = data
            p.height = 1
            return p
        }
    
        if data < p.data {
            p.left = Add(p.left, data)
            if height(p.left) - height(p.right) == 2 {
                if data > p.left.data {
                    fmt.Println("lr")
                    p = lr_rotate(p)
                } else {
                    fmt.Println("ll")
                    p = ll_rotate(p)
                }
            }
        } else if data > p.data {
            p.right = Add(p.right, data)
            if height(p.right) - height(p.left) == 2{
                if data > p.right.data {
                    fmt.Println("rr")
                    p = rr_rotate(p)
                } else {
                    fmt.Println("rl")
                    p = rl_rotate(p)
                }
            }
        } else {
            fmt.Println("Add fail: not allowed same data!")
        }
    
        p.height = max(height(p.left), height(p.right)) + 1
        fmt.Printf("节点:%d, 高度:%d
    ", p.data, p.height)
    
        return p
    }
    
    // Find returns the node whose key equals data in the binary search tree
    // rooted at p. It returns nil when the key is absent or the tree is empty
    // (p == nil), where the recursive version dereferenced p and panicked.
    func Find(p *AVLNode, data int) *AVLNode {
        for p != nil && p.data != data {
            if data < p.data {
                p = p.left
            } else {
                p = p.right
            }
        }
        return p
    }
    
    // maxNode returns the right-most node of the subtree rooted at p, which
    // holds its largest key. It returns nil when p is nil.
    func maxNode(p *AVLNode) *AVLNode {
        for p != nil && p.right != nil {
            p = p.right
        }
        return p
    }
    
    // minNode returns the left-most node of the subtree rooted at p, which
    // holds its smallest key. It returns nil when p is nil.
    func minNode(p *AVLNode) *AVLNode {
        for p != nil && p.left != nil {
            p = p.left
        }
        return p
    }
        
    // Delete removes the node holding data from the AVL tree rooted at p and
    // returns the new root, which may differ from p after rebalancing. Callers
    // should write root = Delete(root, x).
    //
    // An empty tree or an absent key leaves the tree untouched, and p is
    // returned as-is. Returning nil in that case would make
    // root = Delete(root, x) throw away the whole tree.
    func Delete(p *AVLNode, data int) *AVLNode {
        if p == nil {
            return nil
        }
        node := Find(p, data)
        if node == nil {
            return p
        }
        // delete is this file's tree helper, which shadows Go's builtin delete.
        return delete(p, node)
    }
    
    // delete removes node from the AVL subtree rooted at p, rebalancing on the
    // way back up, and returns the new subtree root. It descends by comparing
    // node.data with the keys on the path. node must be present in the
    // subtree, since p is dereferenced without a nil check; Delete ensures this
    // via Find.
    func delete(p, node *AVLNode) *AVLNode {
        switch {
        case node.data < p.data:
            p.left = delete(p.left, node)
            // The left side shrank, so the right side may now be 2 taller.
            if height(p.right)-height(p.left) == 2 {
                r := p.right
                // Use a double rotation only when the right child's inner (left)
                // subtree is strictly taller. When both of its subtrees have equal
                // height, a single rotation is required; RL would leave r unbalanced.
                if height(r.left) > height(r.right) {
                    p = rl_rotate(p)
                } else {
                    p = rr_rotate(p)
                }
            }
        case node.data > p.data:
            p.right = delete(p.right, node)
            // Mirror case: the left side may now be 2 taller.
            if height(p.left)-height(p.right) == 2 {
                l := p.left
                // Double rotation only when the inner (right) subtree is strictly taller.
                if height(l.right) > height(l.left) {
                    p = lr_rotate(p)
                } else {
                    p = ll_rotate(p)
                }
            }
        default:
            if p.left != nil && p.right != nil {
                // Two children: copy in the neighbouring key from the taller side
                // and delete that node instead. Only the taller side can shrink,
                // so p stays within balance without a rotation here.
                if height(p.left) > height(p.right) {
                    pred := maxNode(p.left)
                    p.data = pred.data
                    p.left = delete(p.left, pred)
                } else {
                    succ := minNode(p.right)
                    p.data = succ.data
                    p.right = delete(p.right, succ)
                }
            } else if p.left != nil {
                p = p.left
            } else {
                p = p.right
            }
        }

        if p != nil {
            p.height = max(height(p.left), height(p.right)) + 1
        }
        return p
    }
    
    
    func main() {
        //num := []int{50, 30, 20, 25, 70, 90, 100}  
        num := []int{3, 2, 1, 4, 5, 6, 7, 16, 15, 14, 13, 12, 11, 10, 8, 9}
    
        var root *AVLNode
        for _, v := range num {
            fmt.Printf("插入节点:%d
    ", v)
            root = Add(root, v)
        }
    
        fmt.Println("前序遍历:")
        PreTraverse(root)
        fmt.Printf("
    ")
    
        fmt.Println("中序遍历:")
        InTraverse(root)
        fmt.Printf("
    ")
    
        fmt.Println("后序遍历:")
        PostTraverse(root)
        fmt.Printf("
    ")
    
        avlnode := Find(root, 60)
        if avlnode != nil {
            fmt.Println("查询结果:")
            fmt.Printf("节点:%d 左子节点:%d 右子节点:%d
    ", avlnode.data, avlnode.left.data, avlnode.right.data)
        }
    
        root = Delete(root, 8)
        fmt.Println("删除后前序遍历:")
        PreTraverse(root)
        fmt.Printf("
    ")
    
        fmt.Println("删除后中序遍历:")
        InTraverse(root)
        fmt.Printf("
    ")
    
    
    }
    
    
  • 相关阅读:
    优先队列
    Problem W UVA 662 二十三 Fast Food
    UVA 607 二十二 Scheduling Lectures
    UVA 590 二十一 Always on the run
    UVA 442 二十 Matrix Chain Multiplication
    UVA 437 十九 The Tower of Babylon
    UVA 10254 十八 The Priest Mathematician
    UVA 10453 十七 Make Palindrome
    UVA 10163 十六 Storage Keepers
    UVA 1252 十五 Twenty Questions
  • 原文地址:https://www.cnblogs.com/qxcheng/p/15480801.html
Copyright © 2011-2022 走看看