zoukankan      html  css  js  c++  java
  • iris数据集 决策树实现分类并画出决策树

     1 # coding=utf-8
     2 
     3 import pandas as pd
     4 from sklearn.model_selection import train_test_split
     5 from sklearn import tree
     6 from sklearn.metrics import precision_recall_curve  #准确率与召回率
     7 import numpy as np
     8 #import graphviz
     9 
    10 import os
    11 os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
    12 
    13 
    14 
    15 def get_data():
    16     file_path = "Iris.xlsx"
    17 
    18     data = pd.read_excel(file_path)
    19     loandata = pd.DataFrame(data)
    20     ncol = (len(loandata.keys()))
    21     print(ncol)
    22     # l = list(data.head(0))  #获取表头
    23     # print(l)
    24 
    25     feature1 = []
    26     for i in range(ncol-1):
    27         feature1.append("feature"+str(i))
    28     print(feature1)
    29     iris_x = data.iloc[1:, :ncol-1]#此处有冒号,不显示最后一列
    30     iris_y = data.iloc[1:,ncol-1]#此处没有冒号,直接定位
    31 
    32     '''计算到底有几个类别'''
    33     from collections import Counter
    34     counter = Counter(iris_y)
    35     con = len(counter)
    36     print(counter.keys())
    37     class_names = []
    38     for i in range(con):
    39         class_names.append(list(counter.keys())[i])
    40     x_train, x_test, y_train, y_test = train_test_split(iris_x,iris_y)
    41     print(x_train)
    42     print(y_test)
    43    # return x_train, x_test, y_train, y_test
    44 
    45 
    46 #def dtfit(x_train, x_test, y_train, y_test):
    47 
    48     clf = tree.DecisionTreeClassifier()
    49     clf = clf.fit(x_train,y_train)
    50     predict_data = clf.predict(x_test)
    51     predict_proba = clf.predict_proba(x_test)
    52     from sklearn import metrics
    53     # Do classification task,
    54     # then get the ground truth and the predict label named y_true and y_pred
    55     classify_report = metrics.classification_report(y_test, clf.predict(x_test))
    56     confusion_matrix = metrics.confusion_matrix(y_train, clf.predict(x_train))
    57     overall_accuracy = metrics.accuracy_score(y_train, clf.predict(x_train))
    58     acc_for_each_class = metrics.precision_score(y_train,clf.predict(x_train), average=None)
    59     overall_accuracy = np.mean(acc_for_each_class)
    60     print(classify_report)
    61 
    62 
    63 
    64 
    65     import pydotplus
    66     dot_data = tree.export_graphviz(clf, out_file=None,feature_names=feature1, filled=True, rounded=True, special_characters=True,precision = 4)
    67     graph = pydotplus.graph_from_dot_data(dot_data)
    68     graph.write_pdf("workiris.pdf")
    69     return classify_report
    70 
    71 
    72 if __name__ == "__main__":
    73     x = get_data()
    74     #dtfit(x_train, x_test, y_train, y_test)

    数据地址:http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data

    将 iris.data 另存为 Iris.xlsx 后，注意在第一行补充表头（列名）：pandas 的 read_excel 默认把第一行当作表头读取，否则第一条样本会被误当成列名。

  • 相关阅读:
    构建之法阅读笔记04
    团队项目
    构建之法阅读笔记03
    第6周学习进度
    求最大子数组03
    四则运算4(完结)
    第5周学习进度
    敏捷开发概述
    第4周学习进度
    构建之法阅读笔记01
  • 原文地址:https://www.cnblogs.com/shizhenqiang/p/8204986.html
Copyright © 2011-2022 走看看