zoukankan      html  css  js  c++  java
  • ID3算法 决策树的生成(2)

    决策树的生成,采用ID3算法(也包含了C4.5算法),使用python实现,更新了tree的保存和图示。

    介绍摘自李航《统计学习方法》。

    5.2.3 信息增益比

    信息增益值的大小是相对于训练数据集而言的,并没有绝对意义。在分类问题困难时,也就是说在训练数据集的经验熵大的时候,信息增益值会偏大。反之,信息增益值会偏小。使用信息增益比(information gain ratio)可以对这一问题进行校正。这是特征选择的另一准则。

    定义5.3(信息增益比) 特征A对训练数据集D的信息增益比gR(D,A)定义为其信息增益g(D,A)与训练数据集D的经验熵H(D)之比:

    5.3.2 C4.5的生成算法

    C4.5算法与ID3算法相似,C4.5算法对ID3算法进行了改进。C4.5在生成的过程中,用信息增益比来选择特征。

    算法5.3(C4.5的生成算法)

    输入:训练数据集D,特征集A,阈值altε;

    输出:决策树T。

    (1)如果D中所有实例属于同一类Ck,则置T为单结点树,并将Ck作为该结点的类,返回T;

    (2)如果A=Ø,则置T为单结点树,并将D中实例数最大的类Ck作为该结点的类,返回T;

    (3)否则,按式(5.10)计算A中各特征对D的信息增益比,选择信息增益比最大的特征Ag

    (4)如果Ag的信息增益比小于阈值alt,则置T为单结点树,并将D中实例数最大的类Ck作为该结点的类,返回T;

    (5)否则,对Ag的每一可能值ai,依Ag=ai将D分割为子集若干非空Di,将Di中实例数最大的类作为标记,构建子结点,由结点及其子结点构成树T,返回T;

    (6)对结点i,以Di为训练集,以A-{Ag}为特征集,递归地调用步(1)~步(5),得到子树Ti,返回Ti

      1  # coding:utf-8
      2 import matplotlib.pyplot as plt
      3 import numpy as np
      4 import pylab
      5 
      6 def createDataSet(): #贷款申请样本数据表
      7     dataset = [["青年", "", "", "一般", "拒绝"],
      8                ["青年", "", "", "", "拒绝"],
      9                ["青年", "", "", "", "同意"],
     10                ["青年", "", "", "一般", "同意"],
     11                ["青年", "", "", "一般", "拒绝"],
     12                ["中年", "", "", "一般", "拒绝"],
     13                ["中年", "", "", "", "拒绝"],
     14                ["中年", "", "", "", "同意"],
     15                ["中年", "", "", "非常好", "同意"],
     16                ["中年", "", "", "非常好", "同意"],
     17                ["老年", "", "", "非常好", "同意"],
     18                ["老年", "", "", "", "同意"],
     19                ["老年", "", "", "", "同意"],
     20                ["老年", "", "", "非常好", "同意"],
     21                ["老年", "", "", "一般", "拒绝"],
     22                ]
     23     labels = ["年龄", "有工作", "有房子", "信贷情况"]
     24     return dataset, labels
     25 
     26 def getList(dataset,index=-1):#返回每层列表
     27     alist=[i[index] for i in dataset]
     28     aset=list(set(alist))
     29     acount=[alist.count(aset[j]) for j in range(len(aset))]
     30     return alist,aset,acount
     31 
     32 def getdH(account): #计算H(D)
     33     t=np.sum(account)
     34     return np.sum([-float(a)/t*np.log2(float(a)/t) for a in account])
     35 
     36 def getdaH(acount,ad): #计算H(D,A)
     37     t=np.sum(acount)
     38     return np.sum([[0 if j==0 else -a*float(j)/t/a*np.log2(float(j)/a) for j in b] for a,b in zip(acount,ad)])
     39 
     40 def gethaD(acount): #计算Ha(D)
     41     t=np.sum(acount)
     42     return np.sum([ -float(a)/t*np.log2(float(a)/t)  for a in acount])
     43 
     44 def getaH(dataset,index,c4_5=0): #计算g(D,A),若c4_5=1则采用信息增益比
     45     dlist,dset,dcount= getList(dataset,-1)
     46     hd=getdH(dcount)
     47     alist,aset,acount=getList(dataset,index)
     48     ad=[[[dlist[i] for i in range(len(dlist)) if dataset[i][index]==j].count(k) for k in dset] for j in aset]
     49     if c4_5:
     50         return 0 if gethaD(acount)==0 else (hd-getdaH(acount,ad))/gethaD(acount)
     51     else:
     52         return hd-getdaH(acount,ad)
     53 
     54 def ID3(dataset,labels,tree=[]):#ID3算法
     55     dlist,dset,dcount= getList(dataset,-1)
     56     if len(dset)<2 :
     57         tree.append([dset[0],0])
     58         return
     59     adlist=[[getaH(dataset,i),i] for i in range(len(dataset[0])-1)]
     60     t1= max(adlist,key=lambda x: x[0])
     61     tree.append([labels[t1[1]],2])
     62     alist,aset,acount=getList(dataset,t1[1])
     63     for a in aset:
     64         tree.append([a,1])
     65         ID3([i for i in dataset if i[t1[1]]==a],labels,tree)
     66     return tree
     67 
     68 def showT(tree):#根据Tree列表绘制图像
     69     import sys
     70     reload(sys)
     71     sys.setdefaultencoding('utf-8')
     72     pylab .mpl.rcParams['font.sans-serif'] = ['SimHei']
     73     fig1 = plt.figure(1, (6, 6))
     74     ax = fig1.add_axes([0, 0, 1, 1], frameon=False, aspect=1.)
     75     x,y=0.5,0.85
     76     for i in range(len(tree)):
     77         if tree[i][1]==2:
     78              fig1.text(x,y, tree[i][0],ha="center",size=21,bbox=dict(boxstyle="square", fc="w", ec="k"))
     79              ax.arrow(x,y-0.02, 0.09,-0.11, head_width=0.01, head_length=0.02, fc='k', ec='k')
     80              ax.arrow(x,y-0.02, -0.09,-0.11, head_width=0.01, head_length=0.02, fc='k', ec='k')
     81              x+=0.05
     82              y-=0.1
     83              if i>1:tree[i-2][1]-=1
     84         elif tree[i][1]==1:
     85              fig1.text(x+0.05,y, tree[i][0],ha="center",size=21)
     86              x+=0.05
     87              y-=0.1
     88         else:
     89              fig1.text(x,y, tree[i][0],ha="center",size=21,bbox=dict(boxstyle="square", fc="w", ec="k"))
     90              x-=0.25
     91              y+=0.1
     92              j=i-2
     93              while tree[j][1]==0:
     94                  j=j-2
     95                  x+=0.1
     96                  y+=0.2
     97              tree[j][1]-=1
     98     ax.xaxis.set_visible(False)
     99     ax.yaxis.set_visible(False)
    100     plt.draw()
    101     plt.show()
    102 
    103 dataset,labels=createDataSet()
    104 tree= ID3(dataset,labels) #[["有房子",2],["否",1],["有工作",2],["否",1],["拒绝",0],["是",1],["同意",0],["是",1],["同意",0]]
    105 showT(tree)

  • 相关阅读:
    关于JAVA中的static方法、并发问题以及JAVA运行时内存模型
    【设计模式】抽象工厂模式
    spring mvc4.1.6 + spring4.1.6 + hibernate4.3.11 + mysql5.5.25 开发环境搭建及相关说明
    struts2.3.24 + spring4.1.6 + hibernate4.3.11+ mysql5.5.25开发环境搭建及相关说明
    git中Please enter a commit message to explain why this merge is necessary.
    扒一扒开源世界有哪些licenses?
    string.Format出现异常:输入字符串的格式不正确 Exception during StringFormat
    node-glob学习
    js中对URL进行转码与解码
    程序员如何修炼管理思维
  • 原文地址:https://www.cnblogs.com/qw12/p/5676613.html
Copyright © 2011-2022 走看看