zoukankan      html  css  js  c++  java
  • 决策树实战2-使用决策树预测隐形眼镜类型

    这里使用的是 Python 3.x，对原书的代码做了一些修改。
    其中画图函数直接沿用了原书代码中的函数，同样做了少量修改。

    本书配套的数据和 Python 2.7 版本的源码可以在这里获取：https://www.manning.com/books/machine-learning-in-action

    from math import log
    from ch3.treePlotter import createPlot
    
    def calShannonEntropy(dataset):
        """
        Compute the base-2 Shannon entropy of the class labels in a dataset.

        The class label of every sample is its last element.

        :param dataset: sized sequence of samples; each sample's last item is its class label
        :return: entropy in bits as a float (0.0 for an empty dataset or a single class)
        """
        total = len(dataset)
        tally = {}
        for sample in dataset:
            cls = sample[-1]  # the last column holds the class label
            tally[cls] = tally.get(cls, 0) + 1
        entropy = 0.0
        # Subtract term by term (rather than using sum()) so the float result is unchanged.
        for count in tally.values():
            p = count / total
            entropy -= p * log(p, 2)
        return entropy
    
    
    def splitDate(dataset, axis, value):
        """
        Select the samples whose feature ``axis`` equals ``value`` and drop that column.

        :param dataset: list of samples (lists); each column is one feature, the last is the label
        :param axis: index of the feature column to test and remove
        :param value: the feature value a sample must have to be kept
        :return: new list of new sample lists without column ``axis``; ``dataset`` is not modified
        """
        def _drop_column(row):
            # Build a fresh list so the caller's rows are never mutated.
            trimmed = row[:axis]
            trimmed.extend(row[axis + 1:])
            return trimmed

        return [_drop_column(row) for row in dataset if row[axis] == value]
    
    def keyFeatureSelect(dataset):
        """
        Choose the feature whose split yields the largest information gain.

        :param dataset: list of samples; all items but the last are feature values,
                        the last item is the class label
        :return: column index of the best feature, or -1 if no feature has a
                 strictly positive information gain
        """
        feature_count = len(dataset[0]) - 1
        parent_entropy = calShannonEntropy(dataset)
        best_gain, best_index = 0, -1
        for col in range(feature_count):
            # Weighted entropy of the partitions obtained by splitting on column ``col``.
            conditional = 0
            for v in set([row[col] for row in dataset]):
                part = splitDate(dataset, col, v)
                weight = len(part) / float(len(dataset))
                conditional += weight * calShannonEntropy(part)
            gain = parent_entropy - conditional
            if gain > best_gain:  # strict ">" keeps the earliest column on ties
                best_gain, best_index = gain, col
        return best_index
    
    
    def voteClass(classlist):
        """
        Decide a class label by majority vote.

        Ties are broken in favour of the label that appears first in ``classlist``.

        :param classlist: iterable of (hashable) class labels
        :return: the most frequent label
        :raises ValueError: if ``classlist`` is empty
        """
        classcount = {}
        for x in classlist:
            # The original did "classcount += 1" (TypeError) instead of counting per label.
            classcount[x] = classcount.get(x, 0) + 1
        if not classcount:
            raise ValueError("voteClass() needs at least one class label")
        # max() returns the first maximal key and dicts keep insertion order, so ties
        # go to the earliest-seen label (what the stable descending sort intended).
        # This also replaces dict.iteritems(), which does not exist in Python 3.
        return max(classcount, key=classcount.get)
    
    
    def createTree(dataset, labels):
        """
        Recursively build an ID3 decision tree.

        :param dataset: list of samples (lists); the last item of each sample is its class label
        :param labels: feature names, one per feature column of ``dataset``; not modified
        :return: either a class label (leaf) or a nested dict
                 ``{feature_name: {feature_value: subtree_or_label}}``
        """
        classList = [example[-1] for example in dataset]
        if classList.count(classList[0]) == len(classList):  # every sample has the same label
            return classList[0]
        if len(dataset[0]) == 1:  # only the label column is left: all features are used up
            return voteClass(classList)
        bestFeat = keyFeatureSelect(dataset)
        if bestFeat == -1:
            # No feature gives a positive information gain (e.g. identical feature values
            # with conflicting labels). Using -1 as an index would take the last feature's
            # name and split on the class column itself, so fall back to a majority vote.
            return voteClass(classList)
        bestFeatLabel = labels[bestFeat]
        # Build a new label list instead of deleting from the caller's list; the original
        # book code mutated ``labels`` so the first feature was lost after one run.
        subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
        tree = {bestFeatLabel: {}}  # the tree is represented with nested dicts
        featValues = [example[bestFeat] for example in dataset]
        for value in set(featValues):
            # createTree never mutates ``labels``, so sharing subLabels between branches is safe.
            tree[bestFeatLabel][value] = createTree(splitDate(dataset, bestFeat, value), subLabels)
        return tree
    
    
    def decTreeClassify(inputTree, featLables, testVec):
        """
        Classify one sample by walking a tree produced by ``createTree``.

        :param inputTree: nested dict ``{feature_name: {feature_value: subtree_or_label}}``,
                          or a bare class label when the tree is a single leaf
        :param featLables: feature names, in the same column order as ``testVec``
        :param testVec: feature values of the sample to classify
        :return: the class label stored at the leaf that is reached
        :raises ValueError: if ``testVec`` has a feature value the tree has no branch for
                            (previously this surfaced as an UnboundLocalError)
        """
        node = inputTree
        while isinstance(node, dict):
            featName = next(iter(node))      # an internal node has one key: the feature it tests
            branches = node[featName]        # maps each feature value to a subtree or a label
            featValue = testVec[featLables.index(featName)]
            if featValue not in branches:
                raise ValueError("no branch for feature %r with value %r" % (featName, featValue))
            node = branches[featValue]
        return node
    
    
    def storeTree(inputTree, filename):
        """
        Serialize a trained tree to disk with pickle.

        :param inputTree: the trained tree (nested dict or a bare class label)
        :param filename: path of the file to write; an existing file is overwritten
        :return: None
        """
        import pickle
        # "with" guarantees the file is closed even if pickling fails part-way.
        with open(filename, 'wb') as fw:
            pickle.dump(inputTree, fw)
        print("tree save as", filename)
    
    
    def grabTree(filename):
        """
        Load a tree previously written by ``storeTree``.

        Security note: ``pickle.load`` can execute arbitrary code while deserializing.
        Only load files you created yourself or otherwise trust.

        :param filename: path of the pickle file to read
        :return: the deserialized tree
        """
        print("load tree from disk...")
        import pickle
        # "with" closes the handle, which the original left open.
        with open(filename, "rb") as fr:
            return pickle.load(fr)
    
    
    
    if __name__ == '__main__':
        # Each non-blank line of lenses.txt is one tab-separated sample; the last field is the label.
        # "with" closes the file, and blank lines are skipped so they don't become bogus [''] samples.
        with open('lenses.txt', encoding='utf-8') as fr:
            lense = [line.strip().split('\t') for line in fr if line.strip()]
        train_set = lense[1:]
        test_set = lense[0]  # hold out the first sample as a single test example
        lenseLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
        lenseTree = createTree(train_set, lenseLabels)
        createPlot(lenseTree)
        storeTree(lenseTree, 'lenseTree.txt')
        restoreTree = grabTree('lenseTree.txt')
        print(restoreTree)
        predict = decTreeClassify(restoreTree, lenseLabels, test_set)
        print(predict)
    

    画出的决策树如下图所示：
    pic
    运行结果:

    {'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'myope': 'hard', 'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'myope': 'no lenses', 'hyper': 'soft'}}, 'young': 'soft'}}}}}}

    预测结果:

    no lenses

    参考《机器学习实战》

  • 相关阅读:
    高级特性(7)- 高级AWT
    洛谷 P1948 [USACO08JAN]电话线Telephone Lines
    洛谷 P2015 二叉苹果树
    洛谷 P2014 选课
    洛谷 P1560 [USACO5.2]蜗牛的旅行Snail Trails(不明原因的scanf错误)
    cogs 10. 信号无错传输
    cogs 9. 中心台站建设。。。
    洛谷 P1731 生日蛋糕
    洛谷 P1092 虫食算
    洛谷 P1034 矩形覆盖
  • 原文地址:https://www.cnblogs.com/siucaan/p/9623121.html
Copyright © 2011-2022 走看看