zoukankan      html  css  js  c++  java
  • [Python]实现简单决策树

    基本思路:

      通过香农熵(信息增益)来决定每一层使用哪一个特征做划分;划分后,该特征的每个取值对应一个子节点。若某个子节点中的样本类别全部相同,则直接作为叶节点;若特征已经用完而类别仍不唯一,则通过多数表决法决定该节点的类别。每层消耗一个特征,所以递归深度最多为“特征个数”层。

     1 # -*- coding:utf-8 -*-
     2 import math
     3 import operator
     4 from collections import Counter
     5 
     6 def shannon_ent(dat):
     7   siz = len(dat)
     8   return 0.0 - reduce(lambda x, y: x + y,
     9     map(lambda each: float(each)/siz * math.log(float(each)/siz, 2),
    10     Counter(map(lambda each: each[-1], dat)).values()))
    11 
    12 def split_dataset(dat, axis, val):
    13   ret = filter(lambda each: each[axis] == val, dat)
    14   return map(lambda each: each[:axis]+each[axis+1:], ret)
    15 
    16 def choose_best_feature(dat):
    17   feature_num = len(dat[0]) - 1
    18   base_ent = shannon_ent(dat)
    19   best_info_gain = 0.0
    20   best_feature = -1
    21   for i in range(feature_num):
    22     feature_list = set([each[i] for each in dat])
    23     cur_ent = reduce(lambda x, y: x + y,
    24               map(lambda val: len(split_dataset(dat, i, val))/float(len(dat))*shannon_ent(split_dataset(dat, i, val)),
    25               feature_list))
    26     info_gain = base_ent - cur_ent
    27     if info_gain > best_info_gain:
    28       best_info_gain, best_feature = info_gain, i
    29   return best_feature
    30 
    31 def majority_count(class_list):
    32   class_dict = sorted(dict(Counter(class_list)).iteritems(), key=operator.itemgetter(1))
    33   return class_dict[-1][0]
    34 
    35 def create_tree(dat, label):
    36   class_list = map(lambda each: each[-1], dat)
    37   if class_list.count(class_list[0]) == len(class_list):
    38     return class_list[0]
    39   if len(dat[0]) == 1:
    40     return majority_count(class_list)
    41   best_feature = choose_best_feature(dat)
    42   best_label = label[best_feature]
    43   d_tree = {best_label:{}}
    44   del(label[best_feature])
    45   feature_val = map(lambda each: each[best_feature], dat)
    46   val_set = set(feature_val)
    47   def _update_tree(val):
    48     sub_label = label[:]
    49     d_tree[best_label][val] = create_tree(split_dataset(dat, best_feature, val), sub_label)
    50   map(_update_tree, val_set)
    51   return d_tree
    52 
    53 d = [[1,1,'y'], [1,1,'y'], [1,0,'n'], [0,1,'n'], [0,1,'n']]
    54 l = ['no surfacing', 'flippers']
    55 
    56 print create_tree(d, l)
  • 相关阅读:
    1774:大逃杀
    Angular实现简单数据计算与删除
    IDEA 如何搭建maven 安装、下载、配置(图文)
    win10 Java JDK环境变量配置
    Nginx学习使用
    ASP.NET Core中返回 json 数据首字母大小写问题
    mysql使用遇到的问题
    线程同步以及AutoResetEvent
    Device Class
    Xamarin.Forms之布局压缩
  • 原文地址:https://www.cnblogs.com/kirai/p/6222832.html
Copyright © 2011-2022 走看看