    Machine Learning: ID3 Decision Tree Algorithm, with Python Implementation

    I came across techflow's introduction to the ID3 algorithm, which includes sample code. When I tried running the code, I found errors.
    https://www.cnblogs.com/techflow/p/12935130.html
    After some cleanup, the errors have been sorted out. Sharing the fixed version here.

    import numpy as np
    import math
    from collections import Counter, defaultdict
    
    # Fix the random seed so every run produces the same result
    np.random.seed(100)
    
    def create_data():
        X1 = np.random.rand(50, 1)*100
        X2 = np.random.rand(50, 1)*100
        X3 = np.random.rand(50, 1)*100
        
        # Discretize a raw score: >70 maps to 2, >40 to 1, everything else to 0
        def f(x):
            return 2 if x > 70 else 1 if x > 40 else 0
        
        # Label: 1 if the raw scores sum to more than 150, else 0
        y = X1 + X2 + X3
        Y = y > 150
        Y = Y + 0  # cast the boolean array to integers
        r = map(f, X1)
        X1 = list(r)
        
        r = map(f, X2)
        X2 = list(r)
        
        r = map(f, X3)
        X3 = list(r)
        # Stack the discretized features and the label into one array
        x = np.c_[X1, X2, X3, Y]
        return x, ['courseA', 'courseB', 'courseC']
    
    
    
    def calculate_info_entropy(dataset):
        n = len(dataset)
        # Count the occurrences of each label (Y is the last column)
        labels = Counter(dataset[:, -1])
        entropy = 0.0
        # Apply the information-entropy formula: H = -sum(p * log2(p))
        for k, v in labels.items():
            prob = v / n
            entropy -= prob * math.log(prob, 2)
        return entropy
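
    # Sanity check (my addition, not in the original post): a 50/50 label
    # split should yield exactly 1 bit of entropy, and a single-class
    # dataset should yield 0.
    _demo = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    assert calculate_info_entropy(_demo) == 1.0
    assert calculate_info_entropy(_demo[:1]) == 0.0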
    
    def split_dataset(dataset, idx):
        # idx is the index of the feature to split on
        splitData = defaultdict(list)
        for data in dataset:
            # Drop the idx-th feature value; it is no longer needed below this split
            splitData[data[idx]].append(np.delete(data, idx))
        for k, v in splitData.items():
            splitData[k] = np.array(v)
        return splitData.keys(), splitData.values()
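
    # Quick illustration (my addition): splitting a tiny array on feature 0.
    # Rows that share a value in column 0 land in the same group, and the
    # column itself is removed from each group.
    _toy = np.array([[0, 5, 1], [0, 6, 0], [1, 7, 1]])
    for _k, _g in zip(*split_dataset(_toy, 0)):
        print(_k, _g.tolist())
    # prints: 0 [[5, 1], [6, 0]]
    #         1 [[7, 1]]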
    
    def choose_feature_to_split(dataset):
        n = len(dataset[0])-1
        m = len(dataset)
        # Entropy before the split
        entropy = calculate_info_entropy(dataset)
        bestGain = 0.0
        feature = -1
        for i in range(n):
            # Split the dataset on feature i
            split_data = split_dataset(dataset, i)[1]
            new_entropy = 0.0
            # Entropy after the split, weighted by subset size
            for data in split_data:
                prob = len(data) / m
                new_entropy += prob * calculate_info_entropy(data)
            # Information gain from splitting on feature i
            gain = entropy - new_entropy
            if gain > bestGain:
                bestGain = gain
                feature = i
        return feature
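
    # Quick check (my addition): in this toy data, feature 0 perfectly
    # predicts the label while feature 1 is pure noise, so the best
    # split should be on feature 0.
    _toy = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]])
    assert choose_feature_to_split(_toy) == 0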
    
    def create_decision_tree(dataset, feature_names):
        dataset = np.array(dataset)
        counter = Counter(dataset[:, -1])
        # If only one class remains in the dataset, return it directly
        if len(counter) == 1:
            return dataset[0, -1]
        
        # If all features have been used up, return the majority class
        if len(dataset[0]) == 1:
            return counter.most_common(1)[0][0]
        
        # Find the best feature to split on
        fidx = choose_feature_to_split(dataset)
        fname = feature_names[fidx]
        
        node = {fname: {}}
        feature_names.remove(fname)
        
        # Recurse: build a subtree for each value of the chosen feature
        vals, split_data = split_dataset(dataset, fidx)
        for val, data in zip(vals, split_data):
            node[fname][val] = create_decision_tree(data, feature_names[:])
        return node
    
    dataset, feature_names = create_data()
    tree = create_decision_tree(dataset, feature_names.copy())
    
    tree
    
    {'courseA': {0: {'courseC': {0: {'courseB': {0: 0, 1: 0, 2: 0}},
        1: 0,
        2: {'courseB': {0: 0, 1: 1, 2: 1}}}},
      1: {'courseC': {0: 0, 1: {'courseB': {0: 0, 1: 0}}, 2: 1}},
      2: {'courseC': {0: {'courseB': {0: 0, 1: 1, 2: 1}},
        1: {'courseB': {0: 1, 1: 1, 2: 1}},
        2: 1}}}}
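
    Reading the nested dict: each outer key names the feature tested at that
    node, each inner key is one value of that feature, and the associated
    value is either a subtree or a class label. For example, a sample with
    courseA == 1 and courseC == 2 is classified as 1.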
    
    def classify(node, feature_names, data):
        # The feature tested at the current node
        key = list(node.keys())[0]
        node = node[key]
        idx = feature_names.index(key)
        
        # Recurse according to the feature value
        pred = None
        for key in node:
            # Found the branch matching this sample's feature value
            if data[idx] == key:
                # If a subtree hangs below, recurse into it; otherwise return the leaf
                if isinstance(node[key], dict):
                    pred = classify(node[key], feature_names, data)
                else:
                    pred = node[key]
                    
        # If no branch matched, fall back to the first leaf child found
        if pred is None:
            for key in node:
                if not isinstance(node[key], dict):
                    pred = node[key]
                    break
        return pred
    
    classify(tree, feature_names, [1,0,1])
    
    0
    
    classify(tree, feature_names, [2,2,1])
    
    1
    
    classify(tree, feature_names, [1,1,1])
    
    0
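
    As a rough check (my addition, not part of the original post), we can feed
    every training sample back through classify and compare against the stored
    labels; an unpruned ID3 tree should fit this small training set almost
    perfectly.

    # Training-set accuracy (illustrative sketch, reusing the objects above)
    correct = sum(
        classify(tree, feature_names, row[:-1]) == row[-1]
        for row in dataset
    )
    print(correct / len(dataset))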
  • Original article: https://www.cnblogs.com/StitchSun/p/12937343.html