zoukankan      html  css  js  c++  java
  • 决策树

    # 生成决策树
    from math import log
    import operator
    import pickle
    
    
    def createDataSet():
        """Return the toy 'is it a fish?' dataset and its feature names.

        Each row is [no surfacing?, flippers?, class label]; the last element
        of every row is the label, the two leading elements are 0/1 features.
        """
        samples = [
            [1, 1, 'yes'],
            [1, 1, 'yes'],
            [1, 0, 'no'],
            [0, 1, 'no'],
            [0, 1, 'no'],
        ]
        feature_names = ['no surfacing', 'flippers']
        return samples, feature_names
    
    
    def calcShannonEnt(data_set):
        """Return the Shannon entropy of the class labels in data_set.

        The label is the last element of each row; the entropy is
        H = -sum(p(label) * log2(p(label))) over the distinct labels.

        :param data_set: list of rows whose last element is the class label
        :return: entropy as a float (0.0 for a single-class dataset)
        """
        total = len(data_set)
        tally = {}
        for row in data_set:  # count occurrences of each label
            label = row[-1]
            tally[label] = tally.get(label, 0) + 1
        entropy = 0.0
        for count in tally.values():
            p = float(count) / total  # probability of this label
            entropy -= p * log(p, 2)
        return entropy
    
    
    def splitDataSet(data_set, axis, value):
        """Select the rows where feature `axis` equals `value`, dropping that column.

        :param data_set: list of feature rows
        :param axis: index of the feature column to filter on (and remove)
        :param value: value a row must have at `axis` to be kept
        :return: a new list of new rows; the input rows are never mutated
        """
        # concatenating the two slices rebuilds the row without column `axis`
        return [row[:axis] + row[axis + 1:] for row in data_set if row[axis] == value]
    
    
    def chooseBestFeatureToSplit(data_set):
        """Return the index of the feature with the highest information gain.

        :param data_set: rows whose last element is the class label
        :return: column index of the best feature, or -1 when no split improves
                 on the base entropy
        """
        feature_count = len(data_set[0]) - 1  # last column is the label, not a feature
        base_entropy = calcShannonEnt(data_set)  # disorder before any split
        total_rows = float(len(data_set))
        best_gain, best_index = 0.0, -1
        for col in range(feature_count):
            distinct_values = {row[col] for row in data_set}
            conditional_entropy = 0.0
            for val in distinct_values:
                subset = splitDataSet(data_set, col, val)
                weight = len(subset) / total_rows  # P(feature == val)
                conditional_entropy += weight * calcShannonEnt(subset)
            gain = base_entropy - conditional_entropy
            if gain > best_gain:
                best_gain, best_index = gain, col
        return best_index
    
    
    def majorityCnt(class_list):
        """Return the most frequent class label in class_list.

        Ties are broken by first occurrence in class_list (same result as the
        original stable descending sort).

        Fix: the original called sorted(...) INSIDE the tally loop, re-sorting
        the whole count dict on every vote (accidental O(n^2 log n)); the tally
        is now built first and the winner picked once.

        :param class_list: non-empty list of class labels
        :return: the label with the highest count
        """
        class_count = {}
        for vote in class_list:
            class_count[vote] = class_count.get(vote, 0) + 1
        # max scans insertion order, so the first label to reach the top count wins
        return max(class_count.items(), key=operator.itemgetter(1))[0]
    
    
    def createTree(data_set, labels):
        """Recursively build an ID3 decision tree as nested dicts.

        :param data_set: list of rows; last element of each row is the class label
        :param labels: feature names, parallel to the feature columns
        :return: nested dict {feature_name: {feature_value: subtree_or_label}},
                 or a bare class label for a pure (or feature-exhausted) branch

        Fix: the original did `del labels[best_feat]`, destructively mutating
        the CALLER's label list; this version works on a copy so `labels` is
        left intact. The returned tree is unchanged.
        """
        class_list = [row[-1] for row in data_set]
        if class_list.count(class_list[0]) == len(class_list):
            return class_list[0]  # branch is pure: every row carries the same label
        if len(data_set[0]) == 1:
            return majorityCnt(class_list)  # features exhausted: majority vote
        best_feat = chooseBestFeatureToSplit(data_set)
        best_feat_label = labels[best_feat]
        # copy without the chosen feature so neither the caller nor sibling
        # recursion branches see each other's deletions
        sub_labels = labels[:best_feat] + labels[best_feat + 1:]
        my_tree = {best_feat_label: {}}
        for value in set(row[best_feat] for row in data_set):
            my_tree[best_feat_label][value] = createTree(
                splitDataSet(data_set, best_feat, value), sub_labels)
        return my_tree
    
    
    def classify(input_tree, feat_tabels, test_vec):
        """Walk the decision tree to classify one feature vector.

        :param input_tree: nested dict tree as produced by createTree
        :param feat_tabels: feature-name list mapping tree keys to test_vec indices
        :param test_vec: feature values of the sample being classified
        :return: the predicted class label
        :raises KeyError: when the sample's feature value has no branch in the
                          tree (the original crashed with UnboundLocalError here)
        """
        first_str = list(input_tree.keys())[0]  # feature tested at this node
        second_dict = input_tree[first_str]     # branches keyed by feature value
        feat_index = feat_tabels.index(first_str)
        value = test_vec[feat_index]
        if value not in second_dict:
            raise KeyError('no branch for value %r of feature %r' % (value, first_str))
        subtree = second_dict[value]
        if isinstance(subtree, dict):
            return classify(subtree, feat_tabels, test_vec)  # descend one level
        return subtree  # hit a leaf: this is the label
    
    
    def storeTree(input_tree, filename):
        """Serialize input_tree to filename with pickle.

        Fix: dropped the redundant fw.close() — the `with` block already
        closes the file on exit (and did so even when dump raised).

        :param input_tree: any picklable object (the decision-tree dict here)
        :param filename: path of the file to (over)write in binary mode
        """
        with open(filename, 'wb') as fw:
            pickle.dump(input_tree, fw)
    
    
    def grabTree(filename):
        """Load and return a tree previously saved with storeTree.

        NOTE(review): pickle.load can execute arbitrary code from a malicious
        file — only read files this program wrote itself.

        :param filename: path of a pickle file written by storeTree
        :return: the unpickled object
        """
        fr = open(filename, 'rb')
        try:
            return pickle.load(fr)
        finally:
            fr.close()
    

     绘制决策树

    import matplotlib.pyplot as plt
    # 把决策树的字典放入createPlot(dict) 即可
    
    decision_node = dict(boxstyle='sawtooth', fc='0.8')  # box style for internal decision nodes; fc='0.8' = light-grey face colour
    leaf_node = dict(boxstyle='round4', fc='0.8')  # box style for leaf (class-label) nodes
    arrow_args = dict(arrowstyle='<-')  # arrow drawn pointing from the child back to its parent
    
    
    def plotNode(node_text, center_pt, parent_pt, node_type):  # draw one annotated node with an arrow from its parent
        """Draw node_text in a node_type box at center_pt, arrowed from parent_pt.

        Relies on createPlot() having stored the target axes on createPlot.ax1;
        both points are in axes-fraction coordinates (0..1).
        """
        createPlot.ax1.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
                                xytext=center_pt, textcoords='axes fraction', va='center',
                                ha='center', bbox=node_type, arrowprops=arrow_args)
    
    
    def getNumLeafs(my_tree):
        """Count the leaf nodes (non-dict branch values) of a tree dict.

        :param my_tree: nested dict {feature: {value: subtree_or_label}}
        :return: total number of class-label leaves
        """
        root = next(iter(my_tree))  # the single feature name at this node
        count = 0
        for branch in my_tree[root].values():
            # a dict branch is a subtree to recurse into; anything else is a leaf
            count += getNumLeafs(branch) if isinstance(branch, dict) else 1
        return count
    
    
    def getTreeDepth(my_tree):
        """Return the number of decision levels along the deepest path.

        A tree whose branches are all leaves has depth 1; each nested subtree
        adds one level.

        :param my_tree: nested dict {feature: {value: subtree_or_label}}
        :return: maximum depth as an int
        """
        root = next(iter(my_tree))  # feature tested at this node
        branch_depths = [
            1 + getTreeDepth(child) if isinstance(child, dict) else 1
            for child in my_tree[root].values()
        ]
        return max(branch_depths, default=0)
    
    
    def retrieveTree(i):
        """Return the i-th canned tree fixture used to exercise the plot code.

        NOTE(review): 'yse' in the second fixture looks like a typo for 'yes'
        in the original data; preserved byte-for-byte since it is runtime data.

        :param i: 0 or 1, selecting which fixture to return
        :return: a nested decision-tree dict
        """
        simple = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
        deeper = {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yse'}}, 1: 'no'}}}}
        return (simple, deeper)[i]
    
    
    def plotMidText(cntr_pt, parent_pt, txt_string):
        """Write txt_string halfway along the edge from parent_pt to cntr_pt.

        Used to label each branch with the feature value it corresponds to;
        coordinates are axes fractions, drawn on createPlot.ax1.
        """
        # midpoint = child point plus half the delta towards the parent
        mid_x = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
        mid_y = (parent_pt[1] - cntr_pt[1]) / 2 + cntr_pt[1]
        createPlot.ax1.text(mid_x, mid_y, txt_string)
    
    
    def createPlot(in_tree):
        """Render the decision-tree dict in_tree with matplotlib and show it.

        Sets up module-wide state the recursive plotTree relies on: the target
        axes (createPlot.ax1), the global width/depth scale factors, and the
        plotTree.xOff / plotTree.yOff drawing-cursor function attributes.
        """
        fig = plt.figure(1, facecolor='white')
        fig.clf()  # clear any previous drawing on figure 1
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # frameless, tickless axes shared via function attribute
        global plotTree_w
        global plotTree_d
        plotTree_w = float(getNumLeafs(in_tree))  # leaf count -> horizontal scale (one slot per leaf)
        plotTree_d = float(getTreeDepth(in_tree))  # tree depth -> vertical scale (one step per level)
        plotTree.xOff = -0.5 / plotTree_w  # x cursor starts half a leaf-slot left of the axes
        plotTree.yOff = 1.0  # y cursor starts at the top; the tree is drawn downwards
        plotTree(in_tree, (0.5, 1.0), '')
        plt.show()
    
    
    def plotTree(my_tree, parent_pt, node_txt):     # recursively draw the tree below parent_pt
        """Draw my_tree hanging from parent_pt, labelling the incoming edge node_txt.

        Uses plotTree.xOff / plotTree.yOff as a drawing cursor (seeded by
        createPlot) and the globals plotTree_w / plotTree_d as scale factors.
        """
        num_leafs = getNumLeafs(my_tree)
        depth = getTreeDepth(my_tree)  # NOTE(review): computed but never used below
        first_str = list(my_tree.keys())[0]
        # centre this decision node horizontally over the span of its own leaves
        cntr_pt = (plotTree.xOff + (1.0 + float(num_leafs)) / 2.0 / plotTree_w, plotTree.yOff)
        plotMidText(cntr_pt, parent_pt, node_txt)
        plotNode(first_str, cntr_pt, parent_pt, decision_node)
        second_dict = my_tree[first_str]
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree_d    # step the y cursor down one level (the tree is drawn top-down)
        for key in second_dict.keys():
            if type(second_dict[key]).__name__ == 'dict':
                plotTree(second_dict[key], cntr_pt, str(key))  # dict value: recurse into the subtree
            else:
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree_w  # leaf: advance the x cursor one leaf-slot
                plotNode(second_dict[key], (plotTree.xOff, plotTree.yOff), cntr_pt, leaf_node)
                plotMidText((plotTree.xOff, plotTree.yOff), cntr_pt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree_d  # restore the y cursor before returning to the parent level
    

      

  • 相关阅读:
    ELK+FileBeat 开源日志分析系统搭建-Centos7.8
    ORACLE转换时间戳方法(1546272000)
    由Swap故障引起的ORA-01034: ORACLE not available ORA-27102: out of memory 问题
    数据库设计规范
    数据库字段备注信息声明语法 CDL (Comment Declaration Language)
    渐进式可扩展数据库模型(Progressive Extensible Database Model, pedm)
    使用 ES6 的 Promise 对象和 Html5 的 Dialog 元素,模拟 JS 的 alert, confirm, prompt 方法的阻断执行功能。
    在sed中引入shell变量的四种方法
    参考文献中的[EB/OL]表示什么含义?
    优秀看图软件 XnViewMP v0.97.1 / XnView v2.49.4 Classic
  • 原文地址:https://www.cnblogs.com/luck-L/p/9152844.html
Copyright © 2011-2022 走看看