zoukankan      html  css  js  c++  java
  • Python实现ID3算法

      这是我自己用 Python 写的数据挖掘中的 ID3 算法，现在觉得 Python 是实现算法的最好工具：

      先贴出 ID3 算法的介绍地址：http://wenku.baidu.com/view/cddddaed0975f46527d3e14f.html

      自己写的ID3算法

      1 from __future__ import division
      2 import math
      3 
      4 table = {'age': {'young', 'middle', 'old'}, 'income': {'high', 'middle', 'low'},
      5          'student': {'yes', 'no'}, 'credit': {'good', 'superior'}, 'buy computer': {'yes', 'no'}}
      6 attrIndex = {'age': 0, 'income': 1, 'student': 2, 'credit': 3, 'buy computer': 4}
      7 attrList = ['age', 'income', 'student', 'credit']
      8 allDataSet = [
      9     ['young', 'high', 'no', 'good', 'no'], ['young', 'high', 'no', 'superior', 'no'],
     10     ['middle', 'high', 'no', 'superior', 'yes'], ['old', 'middle', 'no', 'good', 'yes'],
     11     ['young', 'middle', 'no', 'good', 'no'], ['young', 'low', 'yes', 'good', 'yes'],
     12     ['middle', 'high', 'yes', 'good', 'yes'], ['old', 'middle', 'no', 'superior', 'no'],
     13     ['young', 'high', 'yes', 'good', 'yes'], ['middle', 'middle', 'no', 'good', 'no']
     14 ]
     15 
     16 #求熵
     17 def entropy(attr, dataSet):
     18     valueCount = {v: {'yes': 0, 'no': 0, 'count': 0} for v in table[attr]}
     19     for row in dataSet:
     20         vName = row[attrIndex[attr]]
     21         decAttrVal = row[attrIndex['buy computer']] # 'yes' or 'no'
     22         valueCount[vName]['count'] = valueCount[vName]['count'] + 1
     23         valueCount[vName][decAttrVal] = valueCount[vName][decAttrVal] + 1
     24     infoMap = {v: 0 for v in table[attr]}
     25     for v in valueCount:
     26         if valueCount[v]['count'] == 0:
     27             infoMap[v] = 0
     28         else:
     29             p1 = valueCount[v]['yes'] / valueCount[v]['count']
     30             p2 = valueCount[v]['no'] / valueCount[v]['count']
     31             infoMap[v] = - ((0 if p1 == 0 else p1 * math.log(p1, 2)) + (0 if p2 == 0 else p2 * math.log(p2, 2)))
     32     s = 0
     33     for v in valueCount:
     34         s = s + valueCount[v]['count']
     35     propMap = {v: (valueCount[v]['count'] / s) for v in valueCount}
     36     i = 0
     37     for v in valueCount:
     38         i = i + infoMap[v] * propMap[v]
     39     return i
     40 
     41 #定义节点的数据结构
     42 class Node(object):
     43     def __init__(self, attrName):
     44         if attrName != '':
     45             self.attr = attrName
     46             self.childNodes = {v:Node('') for v in table[attrName]}
     47 
     48 #数据筛选
     49 def filtrate(dataSet, condition):
     50     result = []
     51     for row in dataSet:
     52         if row[attrIndex[condition['attr']]] == condition['val']:
     53             result.append(row)
     54     return result
     55 #求最大信息熵
     56 def maxEntropy(dataSet, attrList):
     57     if len(attrList) == 1:
     58         return attrList[0]
     59     else:
     60         attr = attrList[0]
     61         maxE = entropy(attr, dataSet)
     62         for a in attrList:
     63             if maxE < entropy(a, dataSet):
     64                 attr = a
     65         return attr
     66 #判断构建是否结束,当所有的决策属性都相等的时候,就不用在构建决策树了
     67 def endBuild(dataSet):
     68     if len(dataSet) == 1:
     69         return True
     70     buy = dataSet[0][attrIndex['buy computer']]
     71     for row in dataSet:
     72         if buy != row[attrIndex['buy computer']]:
     73             return False
     74 #构建决策树
     75 def buildDecisionTree(dataSet, root, attrList):
     76     if len(attrList) == 0 or endBuild(dataSet):
     77         root.attr = 'buy computer'
     78         root.result = dataSet[0][attrIndex['buy computer']]
     79         root.childNodes = {}
     80         return
     81     attr = root.attr
     82     for v in root.childNodes:
     83         childDataSet = filtrate(dataSet, {"attr":attr, "val":v})
     84         if len(childDataSet) == 0:
     85             root.childNodes[v] = Node('buy computer')
     86             root.childNodes[v].result = 'no'
     87             root.childNodes[v].childNodes = {}
     88             continue
     89         else:
     90             childAttrList = [a for a in attrList]
     91             childAttrList.remove(attr)
     92             if len(childAttrList) == 0:
     93                 root.childNodes[v] = Node('buy computer')
     94                 root.childNodes[v].result = childDataSet[0][attrIndex['buy computer']]
     95                 root.childNodes[v].childNodes = {}
     96             else:
     97                 childAttr = maxEntropy(childDataSet, childAttrList)
     98                 root.childNodes[v] = Node(childAttr)
     99                 buildDecisionTree(childDataSet, root.childNodes[v], childAttrList)
    100 #预测结果
    101 def predict(root, row):
    102     if root.attr == 'buy computer':
    103         return root.result
    104     root = root.childNodes[row[attrIndex[root.attr]]]
    105     return predict(root, row)
    106 
    107 rootAttr = maxEntropy(allDataSet, attrList)
    108 rootNode = Node(rootAttr)
    109 print rootNode.attr
    110 buildDecisionTree(allDataSet, rootNode, attrList)
    111 print predict(rootNode, ['old', 'low', 'yes', 'good'])

             欢迎大家提出建议

  • 相关阅读:
    Active Directory如何用C#进行增加、删除、修改、查询用户与组织单位!
    showModalDialog和showModelessDialog的使用
    如何在GridView中使用DataFromatString
    GridView/DataGrid单元格不换行的问题
    要Gmail、Orkut邀请的请留下你的邮箱
    How to reset security settings back to the defaults
    ASP.NET 2.0 学习笔记 1: session 与 script 应用
    关闭主窗体而不退出主程序 以及如何获取操作系统的关闭、注销信息
    ASP.NET 2.0 学习笔记 2: 页面间传值
    Windows 系统常用设置方法与技巧
  • 原文地址:https://www.cnblogs.com/ArtsCrafts/p/ID3.html
Copyright © 2011-2022 走看看