zoukankan      html  css  js  c++  java
  • 决策树(Decision Tree)SkLearn

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    import csv

    from sklearn import preprocessing
    from sklearn import tree
    from sklearn.feature_extraction import DictVectorizer

    try:
        # sklearn.externals.six was removed in scikit-learn 0.23; StringIO is
        # not used below but is kept importable for backward compatibility.
        from sklearn.externals.six import StringIO
    except ImportError:
        from io import StringIO

    # Read the CSV file and split each row into a feature dict and a class label.
    # Column 0 is skipped (presumably a row ID -- TODO confirm against
    # AllElectronics.csv), the last column is the label, and the columns in
    # between are features keyed by their header name.
    # Python 3's csv module requires a text-mode file opened with newline=''.
    with open(r'AllElectronics.csv', newline='') as allElectronicsData:
        reader = csv.reader(allElectronicsData)
        headers = next(reader)

        print(headers)

        featureList = []
        labelList = []

        for row in reader:
            if not row:
                # csv yields [] for blank lines (e.g. a trailing newline);
                # indexing row[-1] on it would raise IndexError.
                continue
            labelList.append(row[-1])  # last value of each row is the label
            # One dict per row: header name -> value for every feature column.
            rowDict = {headers[i]: row[i] for i in range(1, len(row) - 1)}
            featureList.append(rowDict)

    print(featureList)

    # Vectorize features: DictVectorizer turns each (feature, string value)
    # pair into its own 0/1 column.
    vec = DictVectorizer()
    dummyX = vec.fit_transform(featureList).toarray()

    # get_feature_names() was removed in scikit-learn 1.2 in favour of
    # get_feature_names_out(); support both so the script runs on either.
    if hasattr(vec, 'get_feature_names_out'):
        featureNames = list(vec.get_feature_names_out())
    else:
        featureNames = vec.get_feature_names()

    print("dummyX: " + str(dummyX))
    print(featureNames)

    print("labelList: " + str(labelList))

    # Vectorize class labels as 0/1 indicator values.
    lb = preprocessing.LabelBinarizer()
    dummyY = lb.fit_transform(labelList)
    print("dummyY: " + str(dummyY))

    # Train a decision tree that chooses splits by information gain (entropy).
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    clf = clf.fit(dummyX, dummyY)
    print("clf: " + str(clf))

    # Visualize the model: write the fitted tree in Graphviz .dot format.
    with open("allElectronicInformationGainOri.dot", 'w') as f:
        tree.export_graphviz(clf, feature_names=featureNames, out_file=f)

    oneRowX = dummyX[0, :]
    print("oneRowX: " + str(oneRowX))

    # Copy the row: a bare slice is a view into dummyX, so editing it in
    # place would silently overwrite the first training sample.
    newRowX = oneRowX.copy()
    newRowX[0] = 1
    newRowX[2] = 0
    print("newRowX: " + str(newRowX))

    # predict() expects a 2-D array of shape (n_samples, n_features).
    predictedY = clf.predict(newRowX.reshape(1, -1))
    print("predictedY: " + str(predictedY))
    

      

  • 相关阅读:
    PHP之常用设计模式
    MySQL之慢查询日志和通用查询
    mysql之找回误删数据
    PHPer未来路在何方...
    如何成为更优秀的程序员
    常见的 CSRF、XSS、sql注入、DDOS流量攻击
    API接口TOKEN设计
    成为更好的程序员的八种途径
    奉秉格言
    PHP优化与提升
  • 原文地址:https://www.cnblogs.com/wlc297984368/p/7462684.html
Copyright © 2011-2022 走看看