
The algorithm follows the procedure described in Statistical Learning Methods (《统计学习方法》).
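
For reference, the quantities the code computes are the standard empirical entropy, conditional entropy, and information gain. With D the current dataset, C_k its classes, and D_1, ..., D_n the subsets induced by the values of a feature A:

    \[
    H(D) = -\sum_{k=1}^{K} \frac{|C_k|}{|D|} \log_2 \frac{|C_k|}{|D|},
    \qquad
    H(D \mid A) = \sum_{i=1}^{n} \frac{|D_i|}{|D|} H(D_i),
    \qquad
    g(D, A) = H(D) - H(D \mid A).
    \]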

    import numpy as np
    import pandas as pd
    from graphviz import Digraph
    
    
    class BaseTree(object):
        def __init__(self, feature, label, title=None):
            """
            树基本属性
            :param feature: 特征列
            :param label: 目标列
            :param title: 标题列
            """
            self.root = {}
            self.feature = feature
            self.label = label
            self.title = title
            self.n_classes = tuple(set(label))  # distinct class values
            self.n_features = len(feature)  # note: this is the number of samples (rows)
    
            self._build_tree_(self.root, self.label, self.feature, self.title)
    
        def _build_tree_(self, n, label, feature, title):
            """
            递归构建决策树
            :param n: 对每一层来说该参数是树的当前节点 字典树
            :param label: 每一层递归中目标列会由于上一层中做出的选择而不断减小
            :param feature: 每一层递归中特征也会由于前一层中做出的选择而不断减小 对于一条路径来说 会减少特征列
            :param title: 标题原理同 label
            :return: 
            """
            if feature is not None and feature.shape[1] > 0:  # stop once every feature column has been used
                gain_feature = self._solve_gain_(label, feature, tuple(set(label)), len(feature))
                select_feature = np.argmax(gain_feature)
                selected_feature = feature[:, select_feature]
                states = set(selected_feature)
                n[title[select_feature]] = {}
                feature = np.delete(feature, select_feature, axis=1)
                title_ = np.delete(title, select_feature)
                for i, _ in enumerate(states):
                    inx = np.where(selected_feature == _)[0]
                    label_ = label[inx]
                    if len(set(label_)) == 1:
                        n[title[select_feature]][str(label_[0])] = {}
                        continue
                    feature_ = feature[inx]
                    self._build_tree_(n[title[select_feature]], label_, feature_, title_)
    
        @staticmethod
        def _solve_entropy_(label, n_classes, n_features):
            """
            计算经验熵
            :param label: 
            :param n_classes: 
            :param n_features: 
            :return: 
            """
            sigma = 0.
            for _ in n_classes:
                p = len(np.where(label == _)[0]) / n_features
                sigma += (p * np.log2(p) if p != 0 else 0)
            return -sigma
    
        def _solve_conditional_entropy_(self, label, feature, n_classes):
            """
            计算条件熵
            :param label: 
            :param feature: 
            :param n_classes: 
            :return: 
            """
            group_states = [tuple(set(_)) for _ in feature.transpose()]
            n_features = len(feature)
            ans = []
            for i, branch in enumerate(group_states):
                res = 0.
                for state in branch:
                    inx = np.where(feature[:, i] == state)[0]
                    n_f = len(inx)
                    sigma = self._solve_entropy_(label[inx], n_classes, n_f)
                    res += n_f / n_features * sigma
                ans.append(res)
            return np.array(ans)
    
        def _solve_gain_(self, label, feature, n_classes, n_features):
            """
            计算信息增益
            :param label: 
            :param feature: 
            :param n_classes: 
            :param n_features: 
            :return: 
            """
            return self._solve_entropy_(label, n_classes, n_features)
                   - self._solve_conditional_entropy_(label, feature, n_classes)
    
        def get_root(self):
            """
            获取决策树
            :return: 
            """
            return self.root
    
    
    class DecisionTree(object):
        def __init__(self, title=None):
            self.tree = None
            self.title = title
            self.que = []  # queue of child nodes for the level-order traversal
            self.nums = []  # ids of nodes waiting to be expanded as parents
            self.count = 1  # running counter used to give graph nodes unique names
    
        def fit(self, X, y):
            """
            拟合接口
            :param X: 数据集的特征
            :param y: 数据集的目标
            :return: 
            """
            self.tree = BaseTree(X, y, self.title)
    
        def predict(self, X):
            """Prediction interface (not implemented)."""
            pass
    
        def decision_tree_struct(self):
            """
            获取决策树
            :return: 
            """
            return self.tree.get_root()
    
        def _travl_dict_(self, dot, node):
            """
            层序遍历字典树以生成 graph
            :param dot: 
            :param node: 
            :return: 
            """
            for _ in node.keys():
                s = str(self.nums.pop(0) if self.nums else '0')
                self.que.append(node[_])
                for i, k in enumerate(node[_].keys()):
                    # Edge labels simply alternate No/Yes in key order.
                    dot.edge(_ + '_' + s, k + '_' + str(self.count),
                             label='Yes' if i & 1 else 'No')
                    self.nums.append(self.count)
                    self.count += 1
            if self.que:
                self._travl_dict_(dot, self.que.pop(0))
    
        def generate_graph(self):
            """
            生成决策树图
            :return: 
            """
            dot = Digraph(name='tree', node_attr={'shape': 'circle'}, format='png')
            self._travl_dict_(dot, self.tree.get_root())
            dot.render('decision.gv', view=True)
    
    
    data = pd.read_csv('data.csv', encoding='gbk')
    y = data['index'].astype(int)  # np.int is deprecated/removed in newer NumPy
    train_x = data.drop(['index', 'name'], axis=1)
    # Map the 是/否 (yes/no) answers to 1/0 for every feature column.
    for _ in train_x.keys():
        train_x[_] = train_x[_].map({'是': 1, '否': 0})
    title = list(train_x.keys())
    
    dt = DecisionTree(title)
    X = train_x.values.reshape((-1, len(title)))
    dt.fit(X, y.values)
    d = dt.decision_tree_struct()
    print(d)
    dt.generate_graph()
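
Since data.csv is not included with the post, here is a minimal, self-contained sketch of driving the same classes on a tiny hand-made yes/no dataset; the feature names and values below are made up purely for illustration:

    import numpy as np

    # Hypothetical toy data: three binary features and a binary class label.
    toy_title = ['has_job', 'has_house', 'good_credit']
    toy_X = np.array([
        [0, 0, 0],
        [0, 0, 1],
        [1, 0, 1],
        [1, 1, 0],
        [0, 1, 1],
        [1, 1, 1],
    ])
    toy_y = np.array([0, 0, 1, 1, 1, 1])

    toy_dt = DecisionTree(toy_title)      # DecisionTree / BaseTree as defined above
    toy_dt.fit(toy_X, toy_y)
    print(toy_dt.decision_tree_struct())  # nested dict describing the chosen splits
    # toy_dt.generate_graph()             # uncomment if graphviz is installed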
    

Because no Chinese font is configured for graphviz, the Chinese text cannot be displayed properly, so the index is used instead.
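
If a font with Chinese glyphs is available on the system, graphviz can be told to use it when building the Digraph. A minimal sketch, assuming a font such as SimHei is installed (the font name is an assumption, substitute whatever is available):

    dot = Digraph(name='tree', format='png',
                  graph_attr={'fontname': 'SimHei'},
                  node_attr={'shape': 'circle', 'fontname': 'SimHei'},
                  edge_attr={'fontname': 'SimHei'})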
Test data:

The generated decision tree structure, viewed with graphviz:

  • Original post: https://www.cnblogs.com/darkchii/p/13207599.html