  • Decision tree algorithm

    1. Theory
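    The code in section 2 trains the tree with criterion='entropy': at each node it picks the test with the largest information gain. Here p_k is the fraction of samples in a set D that belong to class k, and D_v are the subsets produced by a test A. scikit-learn grows binary trees, so on the one-hot columns used here each test A is a single column compared with 0.5 (for example age=middle_age <= 0.5). The last two lines work the numbers for the 14 rows in 2.1 (9 yes, 5 no). The root entropy matches node 0 of the tree in 2.3.

```latex
\begin{align*}
H(D) &= -\sum_{k} p_k \log_2 p_k \\
\mathrm{Gain}(D, A) &= H(D) - \sum_{v} \frac{|D_v|}{|D|}\, H(D_v) \\
H(D) &= -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940 \\
\mathrm{Gain}(D, \texttt{age=middle\_age}) &= 0.940 - \tfrac{4}{14}\cdot 0 - \tfrac{10}{14}\cdot 1 \approx 0.226
\end{align*}
```

    The 4 middle_age rows are all yes (entropy 0); the other 10 rows split 5 yes / 5 no (entropy 1).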

    2. Code

      2.1 Training data

    RID,age,income,student,credit_rating,class:buy_computer
    1,youth,high,no,fair,no
    2,youth,high,no,excellent,no
    3,middle_age,high,no,fair,yes
    4,senior,medium,no,fair,yes
    5,senior,low,yes,fair,yes
    6,senior,low,yes,excellent,no
    7,middle_age,low,yes,excellent,yes
    8,youth,medium,no,fair,no
    9,youth,low,yes,fair,yes
    10,senior,medium,yes,fair,yes
    11,youth,medium,yes,excellent,yes
    12,middle_age,medium,no,excellent,yes
    13,middle_age,high,yes,fair,yes
    14,senior,medium,no,excellent,no

      2.2 Code

    import csv

    from sklearn import preprocessing
    from sklearn import tree
    from sklearn.feature_extraction import DictVectorizer

    # Read data: the CSV from 2.1 (adjust the path to where you saved it)
    with open(r'D:lernMLData.csv', 'r', encoding="utf-8", newline='') as data:
        reader = csv.reader(data)
        headers = next(reader)  # first row: column names (features)
        print(headers)

        featureList = []  # one {feature: value} dict per row
        labelList = []    # class label (last column) per row

        for row in reader:
            labelList.append(row[-1])
            rowDict = {}
            # skip column 0 (RID) and the last column (class label)
            for i in range(1, len(row) - 1):
                rowDict[headers[i]] = row[i]
            featureList.append(rowDict)

    print(featureList)

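    # Vectorize features: one-hot encode each categorical value as a 0/1 column
    vec = DictVectorizer()
    dummyX = vec.fit_transform(featureList).toarray()
    print("dummyX:" + str(dummyX))
    print(vec.get_feature_names_out())  # get_feature_names() on scikit-learn < 1.0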

    print("labelList:"+str(labelList))

    #vectorize class labels
    lb= preprocessing.LabelBinarizer()
    dummyY=lb.fit_transform(labelList)
    print("dummyY:"+str(dummyY))

    # Use a decision tree with the entropy (information gain) criterion for classification
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    clf = clf.fit(dummyX, dummyY)
    print("clf:" + str(clf))

    # Visualize the model: write the tree in Graphviz DOT format to DT.dot
    with open("DT.dot", 'w') as f:
        tree.export_graphviz(clf, feature_names=list(vec.get_feature_names_out()), out_file=f)

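    # Predict: take the first training row and modify it
    oneRowX = dummyX[0, :]
    print("oneRowX:" + str(oneRowX))

    # copy() so that editing newRowX does not also overwrite dummyX[0]
    newRowX = oneRowX.copy()
    newRowX[0] = 1  # column 0 is age=middle_age
    newRowX[2] = 0  # column 2 is age=youth
    print("newRowX:" + str(newRowX))

    predictedY = clf.predict([newRowX])
    print("predictedY:" + str(predictedY))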

      2.3 Results (contents of the generated DT.dot)

    digraph Tree {
    node [shape=box] ;
    0 [label="age=middle_age <= 0.5 entropy = 0.94 samples = 14 value = [5, 9]"] ;
    1 [label="student=no <= 0.5 entropy = 1.0 samples = 10 value = [5, 5]"] ;
    0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
    2 [label="credit_rating=excellent <= 0.5 entropy = 0.722 samples = 5 value = [1, 4]"] ;
    1 -> 2 ;
    3 [label="entropy = 0.0 samples = 3 value = [0, 3]"] ;
    2 -> 3 ;
    4 [label="income=low <= 0.5 entropy = 1.0 samples = 2 value = [1, 1]"] ;
    2 -> 4 ;
    5 [label="entropy = 0.0 samples = 1 value = [0, 1]"] ;
    4 -> 5 ;
    6 [label="entropy = 0.0 samples = 1 value = [1, 0]"] ;
    4 -> 6 ;
    7 [label="age=youth <= 0.5 entropy = 0.722 samples = 5 value = [4, 1]"] ;
    1 -> 7 ;
    8 [label="credit_rating=fair <= 0.5 entropy = 1.0 samples = 2 value = [1, 1]"] ;
    7 -> 8 ;
    9 [label="entropy = 0.0 samples = 1 value = [1, 0]"] ;
    8 -> 9 ;
    10 [label="entropy = 0.0 samples = 1 value = [0, 1]"] ;
    8 -> 10 ;
    11 [label="entropy = 0.0 samples = 3 value = [3, 0]"] ;
    7 -> 11 ;
    12 [label="entropy = 0.0 samples = 4 value = [0, 4]"] ;
    0 -> 12 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
    }
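    The root entropy (0.94) and the choice of age=middle_age as the first split can be checked by hand. A minimal pure-Python sketch, re-embedding the 14 rows from 2.1: it scores every one-hot test "feature == value" by information gain, as a binary tree over one-hot columns would. The helper names entropy and best_binary_split are hypothetical, and the sketch ignores scikit-learn's tie-breaking and feature permutation.

```python
from collections import Counter
from math import log2

# (age, income, student, credit_rating, buy_computer) for RID 1..14
ROWS = [
    ("youth", "high", "no", "fair", "no"),
    ("youth", "high", "no", "excellent", "no"),
    ("middle_age", "high", "no", "fair", "yes"),
    ("senior", "medium", "no", "fair", "yes"),
    ("senior", "low", "yes", "fair", "yes"),
    ("senior", "low", "yes", "excellent", "no"),
    ("middle_age", "low", "yes", "excellent", "yes"),
    ("youth", "medium", "no", "fair", "no"),
    ("youth", "low", "yes", "fair", "yes"),
    ("senior", "medium", "yes", "fair", "yes"),
    ("youth", "medium", "yes", "excellent", "yes"),
    ("middle_age", "medium", "no", "excellent", "yes"),
    ("middle_age", "high", "yes", "fair", "yes"),
    ("senior", "medium", "no", "excellent", "no"),
]
FEATURES = ["age", "income", "student", "credit_rating"]


def entropy(labels):
    """Shannon entropy (base 2) of a list of class labels."""
    n = len(labels)
    return -sum(c / n * log2(c / n) for c in Counter(labels).values())


def best_binary_split(rows):
    """Return (gain, feature, value) for the best one-hot test 'feature == value'."""
    base = entropy([r[-1] for r in rows])
    best = None
    for j, feature in enumerate(FEATURES):
        for value in sorted({r[j] for r in rows}):
            left = [r[-1] for r in rows if r[j] == value]
            right = [r[-1] for r in rows if r[j] != value]
            if not left or not right:
                continue  # this test does not split the set
            # weighted average entropy of the two children
            remainder = (len(left) * entropy(left) + len(right) * entropy(right)) / len(rows)
            gain = base - remainder
            if best is None or gain > best[0]:
                best = (gain, feature, value)
    return best


root_entropy = entropy([r[-1] for r in ROWS])
gain, feature, value = best_binary_split(ROWS)
print(f"root entropy = {root_entropy:.3f}")
print(f"best split: {feature}={value}, gain = {gain:.3f}")
```

    The next best test, student == yes, has a gain of only about 0.152, so age=middle_age is a clear winner at the root.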


      2.3.1 Visualizing the result: the Graphviz tool
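    DT.dot is plain Graphviz source, so it can be rendered with Graphviz's dot command (dot -Tpng DT.dot -o DT.png). A minimal Python sketch that calls that command via subprocess; it assumes Graphviz is installed with dot on PATH, and the helper name render_dot and the output name DT.png are hypothetical.

```python
import shutil
import subprocess


def render_dot(dot_path="DT.dot", png_path="DT.png"):
    """Render a Graphviz DOT file to PNG using the `dot` command-line tool."""
    dot = shutil.which("dot")  # locate the Graphviz executable on PATH
    if dot is None:
        raise RuntimeError("Graphviz 'dot' executable not found on PATH")
    # equivalent to running: dot -Tpng DT.dot -o DT.png
    subprocess.run([dot, "-Tpng", dot_path, "-o", png_path], check=True)
    return png_path
```

    Calling render_dot() from the directory where the script in 2.2 wrote DT.dot should produce DT.png next to it.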
    
    
    
  • Original post: https://www.cnblogs.com/yrm1160029237/p/9858972.html