zoukankan      html  css  js  c++  java
  • Python实现决策树

    Python实现无剪枝的决策树

    import math
    import numpy as np
    import pydot
    
    """
    自己实现决策树:无剪枝
    """
    
    
    class Node:
        def __init__(self, data):
            self.attr = -1  # 当前节点上的划分属性
            self.sons = {}  # 该结点的儿子结点
            self.ans = None  # 该结点的答案,只有叶子节点才有,通过此属性判断是否是叶子节点
            self.data = data  # 数据的下标列表
    
        def is_leaf(self):
            return self.ans is not None
    
        def __str__(self):
            return "node_cnt:{} {}".format(len(self.data), "ans:%d" % self.ans if self.ans else "")
    
    
    class DicisionTree:
        def __init__(self, x, y, creteria="id3"):
            self.x = np.array(x)
            self.y = np.array(y)
            if self.x.dtype not in (np.int, np.int64) or self.y.dtype not in (np.int, np.int64):
                raise Exception("这里的决策树只能处理整数类型")
            self.num_class = len(set(y))
            self.num_attr = len(x[0])
            self.creteria = creteria
            self.root = self._build(list(range(len(self.y))), list(range(self.num_attr)))
    
        def _split(self, x, attr):
            # 将数据集x按照属性attr的取值分开
            xset = {}
            for i in x:
                v = self.x[i][attr]
                if v not in xset:
                    xset[v] = []
                xset[v].append(i)
            return xset
    
        def _buildTable(self, x, attrs):
            # 按照属性、属性取值、类别三个维度统计元素个数
            table = [{} for _ in range(self.num_attr)]
            for i in x:
                for attr in range(self.num_attr):
                    v, c = self.x[i][attr], self.y[i]
                    if v not in table[attr]:
                        table[attr][v] = {}
                    if c not in table[attr][v]:
                        table[attr][v][c] = 0
                    table[attr][v][c] += 1
            return table
    
        def _id3(self, table):
            # 根据表求id3的值
            aloga = 0
            rlogr = 0
            for v in table:
                r = 0
                for c in table[v]:
                    aloga += table[v][c] * math.log(table[v][c])
                    r += table[v][c]
                rlogr += r * math.log(r)
            return aloga - rlogr
    
        def _c45(self, table, tlogt, slogs):
            aloga = 0
            rlogr = 0
            for v in table:
                r = 0
                for c in table[v]:
                    aloga += table[v][c] * math.log(table[v][c])
                    r += table[v][c]
                rlogr += r * math.log(r)
            return (tlogt - aloga) / (1 if len(table) == 1 else  rlogr - slogs)
    
        def _gini(self, table):
            gain = 0
            for v in table:
                a2 = 0
                r = 0
                for c in table[v]:
                    a2 += table[v][c] ** 2
                    r += table[v][c]
                gain += a2 / r
            return gain
    
        # 只有C45用到了slogs和tlogt,id3和gini都没有用到
        def _c45_tlogt(self, data):
            slogs = len(data) * math.log(len(data))
            cnt = {}
            for i in data:
                y = self.y[i]
                if y not in cnt:
                    cnt[y] = 0
                cnt[y] += 1
            tlogt = 0
            for i in cnt:
                tlogt += cnt[i] * math.log(cnt[i])
            return tlogt, slogs
    
        def _selectAttr(self, x, attrs):
            # 选择属性
            t = self._buildTable(x, attrs)
            ans_attr, ans_gain = None, -0xfffff
            for attr in attrs:
                if self.creteria == "id3":
                    gain = self._id3(t[attr])
                elif self.creteria == "gini":
                    gain = self._gini(t[attr])
                elif self.creteria == "c45":
                    tlogt, slogs = self._c45_tlogt(x)
                    gain = self._c45(t[attr], tlogt=tlogt, slogs=slogs)
                else:
                    raise Exception("unkown creterial{},the 3 suported creteria are id3,c45,gini".format(self.creteria))
                if ans_gain is None or gain > ans_gain:
                    ans_gain = gain
                    ans_attr = attr
            return ans_attr
    
        def _allsame(self, array):
            x = array[0]
            for i in array:
                if x != i: return False
            return True
    
        def _build(self, data, attrs):
            node = Node(data)
            if self._allsame(self.y[data]) or not attrs:
                node.ans = self.y[data[0]]
                return node
            node.attr = self._selectAttr(data, attrs)
            # print(node.attr, "selected attr")
            attrs.remove(node.attr)
            xset = self._split(data, node.attr)
            for v in xset.keys():
                node.sons[v] = self._build(xset[v], attrs)
            attrs.append(node.attr)  # 将属性复原还给父结点
            return node
    
        def predict(self, data_x):
            def _predict_one(x):
                node = self.root
                while not node.is_leaf():
                    value = x[node.attr]
                    if value in node.sons:
                        node = node.sons[value]
                    else:
                        break
                if node.is_leaf():
                    return node.ans
                return None  # 无答案
    
            return np.array(list(map(_predict_one, data_x)))
    
        def get_node_count(self):
            def dfs(node):
                cnt = 1
                for i in node.sons:
                    cnt += dfs(node.sons[i])
                return cnt
    
            return dfs(self.root)
    
        def export_graphviz(self):
            g = pydot.Dot(graph_type="digraph")
    
            def dfs(node, parent, label):
                if hasattr(dfs, "nodeid"):
                    dfs.nodeid += 1
                else:
                    dfs.nodeid = 0
                me = pydot.Node(str(dfs.nodeid), label=str(node))
                g.add_node(me)
                if parent is not None:
                    g.add_edge(pydot.Edge(parent, me, label=label))
                for k, v in node.sons.items():
                    dfs(v, me, "attr{}={}".format(node.attr, k))
    
            dfs(self.root, None, "")
            g.write("haha.jpg", prog='dot', format="jpg")
    
    
    if __name__ == '__main__':
        # 增益函数的选取:id3,gini,c45
        gain_f = "id3"
        x = np.array([[0, 3, 0], [0, 2, 1], [1, 1, 2], [1, 2, 2], [2, 3, 0], [2, 1, 1]])
        y = np.array([0, 0, 1, 2, 0, 1])
        tree = DicisionTree(x, y, gain_f)
        ans = tree.predict(x)
        print(ans)
        cnt = np.count_nonzero(y == ans)
        print('正确的个数,正确率', cnt, cnt / len(x))
        print('不确定的个数', len([1 for i in range(len(ans)) if ans[i] == 'not found']))
        print('结点总数', tree.get_node_count())
    
    
  • 相关阅读:
    Microsoft Enterprise Library 5.0 系列(二) Cryptography Application Block (初级)
    Microsoft Enterprise Library 5.0 系列(五) Data Access Application Block
    Microsoft Enterprise Library 5.0 系列(八) Unity Dependency Injection and Interception
    Microsoft Enterprise Library 5.0 系列(九) Policy Injection Application Block
    Microsoft Enterprise Library 5.0 系列(三) Validation Application Block (高级)
    软件研发打油诗祝大家节日快乐
    从挖井的故事中想到开发管理中最容易忽视的几个简单道理
    ITIL管理思想的执行工具发布
    管理类软件设计“渔”之演化
    20070926日下午工作流与ITILQQ群 事件管理 讨论聊天记录
  • 原文地址:https://www.cnblogs.com/weiyinfu/p/8488124.html
Copyright © 2011-2022 走看看