zoukankan      html  css  js  c++  java
  • 机器学习之--决策树递归算法实现

    import numpy as np
    import math
    
    def createdatabase():
        """Return the toy training set and its feature names.

        Returns:
            (rows, feature_names) where each row is two 0/1 feature values
            followed by a 'yes'/'no' class label, and feature_names holds one
            one-element list per feature column (the name is item 0).
        """
        rows = [
            [1, 1, 'yes'],
            [1, 1, 'yes'],
            [1, 0, 'no'],
            [0, 1, 'no'],
            [0, 1, 'no'],
        ]
        feature_names = [['no surfacing'], ['flippers']]
        return rows, feature_names
    # Build the module-level dataset and echo it, followed by a blank line.
    dataSet, labels = createdatabase()
    print('dataSet:', dataSet, end='\n\n')
    
    def XN(dataSet):
        """Return the Shannon entropy (base 2) of the class labels in dataSet.

        The class label is taken from the last element of each row. A larger
        value means the classes are more mixed. An empty dataSet gives 0.
        """
        counts = {}
        for row in dataSet:
            tag = row[-1]
            counts[tag] = counts.get(tag, 0) + 1
        n = len(dataSet)
        entropy = 0
        for c in counts.values():
            p = c / n
            entropy -= p * math.log(p, 2)
        return entropy
    # XN(dataSet)
    
    def D_split(dataSet, axis, value):
        """Return the rows whose column `axis` equals `value`, with that column removed.

        Each returned row is a new list one element shorter than the input
        row; the input rows themselves are not modified.
        """
        matches = []
        for row in dataSet:
            if row[axis] != value:
                continue
            reduced = row[:axis]          # slice -> fresh list, safe to extend
            reduced.extend(row[axis + 1:])
            matches.append(reduced)
        return matches
    # partdata = D_split(dataSet,0,1)
    #                                           partdata [[1, 'yes'], [1, 'yes'], [0, 'no']]
    
    def chooseaxis(dataSet):
        """Return the index of the best feature column to split on.

        For each feature column (every column except the last), compute the
        weighted entropy of the subsets produced by splitting on each distinct
        value. The column with the lowest weighted entropy wins, i.e. the
        highest information gain. A column only replaces the default 0 when its
        weighted entropy is strictly lower than the best seen so far, starting
        from the entropy of the whole dataSet. Prints the chosen index and its
        entropy.
        """
        total_rows = len(dataSet)
        lowest_entropy = XN(dataSet)
        best = 0
        feature_count = len(dataSet[0]) - 1
        for col in range(feature_count):
            distinct = {row[col] for row in dataSet}
            cond_entropy = 0
            for v in distinct:
                subset = D_split(dataSet, col, v)
                # weight each subset's entropy by its share of the rows
                cond_entropy += len(subset) / total_rows * XN(subset)
            if cond_entropy < lowest_entropy:
                lowest_entropy, best = cond_entropy, col
        print('bestaxis:{},XN:{}'.format(best, lowest_entropy))
        return best
    
    def major(classlist):
        """Return the most frequent label in classlist (majority vote).

        Ties are resolved in favour of the label whose first occurrence comes
        last, because the sort by count is stable. An empty classlist raises
        IndexError.
        """
        tally = {}
        for label in classlist:
            tally[label] = tally.get(label, 0) + 1
        ranked = sorted(tally.items(), key=lambda pair: pair[1])
        return ranked[-1][0]
    
    # Collect the class label (third column) of every row of the module-level dataset.
    classlist = [row[2] for row in dataSet]
    print('classlist:', classlist)
    
    def createtree(dataSet, labels):
        """Recursively build an information-gain decision tree.

        Args:
            dataSet: list of rows; each row holds the feature values followed
                by the class label in the last position.
            labels: feature names, one entry per feature column. Each entry is
                a one-element list whose item 0 is the name, as produced by
                ``createdatabase``. The caller's list is not modified.

        Returns:
            A class label (leaf) or a nested dict of the form
            ``{feature_name: {feature_value: subtree, ...}}``.
        """
        classlist = [row[-1] for row in dataSet]
        # Leaf: every remaining row has the same class.
        if classlist.count(classlist[0]) == len(classlist):
            return classlist[0]
        # Leaf: only the label column is left, so fall back to a majority vote.
        if len(dataSet[0]) == 1:
            return major(classlist)
        axis = chooseaxis(dataSet)
        # Work on a copy: deleting from `labels` directly would mutate the
        # caller's list.
        remaining = labels[:]
        label_choose = remaining.pop(axis)
        name = label_choose[0]
        mytree = {name: {}}
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # each branch is built exactly once and keys follow the data order.
        # Each recursive call copies `remaining` itself, so siblings cannot
        # interfere with each other.
        for value in dict.fromkeys(row[axis] for row in dataSet):
            mytree[name][value] = createtree(D_split(dataSet, axis, value), remaining)
        print("mytree:\n", mytree)
        return mytree
    # Build the tree for the toy dataset; createtree prints each subtree it builds.
    createtree(dataSet, labels)
    # Expected result:
    # {'no surfacing': {1: {'flippers': {1: 'yes', 0: 'no'}}, 0: 'no'}}
  • 相关阅读:
    vue-cli(脚手架)学习
    vue-cli(脚手架)
    js时间戳转时间格式
    jQ获取窗口尺寸
    前端加密MD5
    vue项目准备工作(一)
    Oracle数据错删找回
    正则表达式匹配【全角字符】
    数据库分区、分表、分库、分片
    oracle的 分表 详解 -----表分区
  • 原文地址:https://www.cnblogs.com/cxhzy/p/10627390.html
Copyright © 2011-2022 走看看