zoukankan      html  css  js  c++  java
  • 决策树

    基础内容可以直接看这篇博客

    下面的demo是使用决策树算法的一个例子,使用的数据链接如下:

    https://files.cnblogs.com/files/henuliulei/决策树.zip

    from sklearn.feature_extraction import DictVectorizer
    from sklearn import tree
    from sklearn import preprocessing
    import csv
    # 读入数据
    Dtree = open(r'AllElectronics.csv', 'r')
    reader = csv.reader(Dtree)
    
    # 获取第一行数据
    headers = reader.__next__()
    print(headers)
    
    # 定义两个列表
    featureList = []
    labelList = []
    
    #
    for row in reader:
        # 把label存入list
        labelList.append(row[-1])#类别标签
        rowDict = {}
        for i in range(1, len(row)-1):
            #建立一个数据字典
            rowDict[headers[i]] = row[i]
        # 把数据字典存入list
        featureList.append(rowDict)
    
    print(featureList)
    # 把数据转换成01表示
    vec = DictVectorizer()
    x_data = vec.fit_transform(featureList).toarray()
    print("x_data: " + str(x_data))
    
    # 打印属性名称
    print(vec.get_feature_names())
    
    # 打印标签
    print("labelList: " + str(labelList))
    
    # 把标签转换成01表示
    lb = preprocessing.LabelBinarizer()
    y_data = lb.fit_transform(labelList)
    print("y_data: " + str(y_data))
    # 创建决策树模型
    model = tree.DecisionTreeClassifier(criterion='entropy')##建立决策树模型,基于信息熵
    # 输入数据建立模型
    model.fit(x_data, y_data)
    # 测试
    x_test = x_data[0]
    print("x_test: " + str(x_test))
    
    predict = model.predict(x_test.reshape(1,-1))
    print("predict: " + str(predict))
    # 导出决策树
    # pip install graphviz
    # http://www.graphviz.org/
    import graphviz
    
    dot_data = tree.export_graphviz(model,
                                    out_file = None,
                                    feature_names = vec.get_feature_names(),
                                    class_names = lb.classes_,
                                    filled = True,
                                    rounded = True,
                                    special_characters = True)
    graph = graphviz.Source(dot_data)
    graph.render('computer')#生成一个决策树图表文件

    下面的demo是用决策树进行线性二分类

     1 import matplotlib.pyplot as plt
     2 import numpy as np
     3 from sklearn.metrics import classification_report
     4 from sklearn import tree
     5 # 载入数据
     6 data = np.genfromtxt("LR-testSet.csv", delimiter=",")
     7 x_data = data[:,:-1]
     8 y_data = data[:,-1]
     9 
    10 plt.scatter(x_data[:,0],x_data[:,1],c=y_data)
    11 plt.show()
    12 # 创建决策树模型
    13 model = tree.DecisionTreeClassifier()
    14 # 输入数据建立模型
    15 model.fit(x_data, y_data)
    16 # 导出决策树
    17 import graphviz # http://www.graphviz.org/
    18 
    19 dot_data = tree.export_graphviz(model,
    20                                 out_file = None,
    21                                 feature_names = ['x','y'],
    22                                 class_names = ['label0','label1'],
    23                                 filled = True,
    24                                 rounded = True,
    25                                 special_characters = True)
    26 graph = graphviz.Source(dot_data)
    27 # 获取数据值所在的范围
    28 x_min, x_max = x_data[:, 0].min() - 1, x_data[:, 0].max() + 1
    29 y_min, y_max = x_data[:, 1].min() - 1, x_data[:, 1].max() + 1
    30 
    31 # 生成网格矩阵
    32 xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
    33                      np.arange(y_min, y_max, 0.02))
    34 
    35 z = model.predict(np.c_[xx.ravel(), yy.ravel()])# ravel与flatten类似,多维数据转一维。flatten不会改变原始数据,ravel会改变原始数据
    36 z = z.reshape(xx.shape)
    37 # 等高线图
    38 cs = plt.contourf(xx, yy, z)
    39 # 样本散点图
    40 plt.scatter(x_data[:, 0], x_data[:, 1], c=y_data)
    41 plt.show()
    42 predictions = model.predict(x_data)
    43 print(classification_report(predictions,y_data))

    下面这个demo是用决策树进行非线性二分类

     1 import matplotlib.pyplot as plt
     2 import numpy as np
     3 from sklearn.metrics import classification_report
     4 from sklearn import tree
     5 from sklearn.model_selection import train_test_split
     6 
     7 # 载入数据
     8 data = np.genfromtxt("LR-testSet2.txt", delimiter=",")
     9 x_data = data[:, :-1]
    10 y_data = data[:, -1]
    11 
    12 plt.scatter(x_data[:, 0], x_data[:, 1], c=y_data)
    13 plt.show()
    14 #分割数据
    15 x_train,x_test,y_train,y_test = train_test_split(x_data, y_data)
    16 
    17 # 创建决策树模型
    18 # max_depth,树的深度
    19 # min_samples_split 内部节点再划分所需最小样本数
    20 model = tree.DecisionTreeClassifier(max_depth=7,min_samples_split=4)
    21 # 输入数据建立模型
    22 model.fit(x_train, y_train)
    23 # 导出决策树
    24 import graphviz # http://www.graphviz.org/
    25 
    26 dot_data = tree.export_graphviz(model,
    27                                 out_file = None,
    28                                 feature_names = ['x','y'],
    29                                 class_names = ['label0','label1'],
    30                                 filled = True,
    31                                 rounded = True,
    32                                 special_characters = True)
    33 graph = graphviz.Source(dot_data)
    34 # 获取数据值所在的范围
    35 x_min, x_max = x_data[:, 0].min() - 1, x_data[:, 0].max() + 1
    36 y_min, y_max = x_data[:, 1].min() - 1, x_data[:, 1].max() + 1
    37 
    38 # 生成网格矩阵
    39 xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
    40                      np.arange(y_min, y_max, 0.02))
    41 
    42 z = model.predict(np.c_[xx.ravel(), yy.ravel()])# ravel与flatten类似,多维数据转一维。flatten不会改变原始数据,ravel会改变原始数据
    43 z = z.reshape(xx.shape)
    44 # 等高线图
    45 cs = plt.contourf(xx, yy, z)
    46 # 样本散点图
    47 plt.scatter(x_data[:, 0], x_data[:, 1], c=y_data)
    48 plt.show()
    49 predictions = model.predict(x_train)
    50 print(classification_report(predictions,y_train))
    51 predictions = model.predict(x_test)
    52 print(classification_report(predictions,y_test))
     1 from sklearn import tree
     2 import numpy as np
     3 # 载入数据
     4 data = np.genfromtxt("cart.csv", delimiter=",")
     5 x_data = data[1:,1:-1]
     6 y_data = data[1:,-1]
     7 # 创建决策树模型
     8 model = tree.DecisionTreeClassifier()
     9 # 输入数据建立模型
    10 model.fit(x_data, y_data)
    11 # 导出决策树
    12 import graphviz # http://www.graphviz.org/
    13 
    14 dot_data = tree.export_graphviz(model,
    15                                 out_file = None,
    16                                 feature_names = ['house_yes','house_no','single','married','divorced','income'],
    17                                 class_names = ['no','yes'],
    18                                 filled = True,
    19                                 rounded = True,
    20                                 special_characters = True)
    21 graph = graphviz.Source(dot_data)
    22 graph.render('cart')
    23 print(graph)
  • 相关阅读:
    模板网站
    用servlet和jsp做探索数据库
    Hibernate和jsp做数据库单表的增删改查
    拦截器
    校验器-对提交的用户名和密码进行过滤
    使用my exclipse对数据库进行操作(4)
    如何正确关闭游戏服务器
    Ehcache 入门详解
    自动重置 Language Level默认为5与 Java Complier默认为1.5
    洪均生谈初学者练习(怎样认识太级拳和怎样进行练习)
  • 原文地址:https://www.cnblogs.com/henuliulei/p/11821506.html
Copyright © 2011-2022 走看看