最简代码:
# Minimal decision-tree classification example.
from sklearn import tree

# Training data: one row per sample, two numeric features per row.
# (What the two numbers measure is not stated in this snippet.)
features = [[300, 2], [450, 2], [200, 8], [150, 9]]

# Class label for each row of `features`, in the same order.
labels = ['apple', 'apple', 'orange', 'orange']

# Build the classifier and train it on the labelled samples.
clf = tree.DecisionTreeClassifier()
clf = clf.fit(features, labels)

# predict() takes a list of samples, hence the nested list for a single one.
print(clf.predict([[400, 6]]))
预测代码:
代码:
# -*- coding: UTF-8 -*-
# Train a decision tree on the lenses data set and predict the class of one
# encoded sample.

# sklearn.externals.six is no longer shipped with scikit-learn; the standard
# library provides the same StringIO class.
from io import StringIO

import numpy as np
import pandas as pd
import pydotplus
from sklearn import tree
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

# NOTE(review): StringIO, OneHotEncoder, np and pydotplus are not used in the
# code visible here; they are kept in case later code relies on them.

if __name__ == '__main__':
    # Load the data file: one record per line, fields separated by a single
    # space. Blank lines are skipped so they cannot raise IndexError when the
    # feature columns are extracted below.
    # NOTE(review): splitting on ' ' assumes no field value itself contains a
    # space -- confirm against the actual data file.
    with open('datalenses.txt', 'r') as fr:
        lenses = [inst.strip().split(' ') for inst in fr if inst.strip()]

    # The last field of every record is its class label.
    lenses_target = [each[-1] for each in lenses]

    # Names of the feature columns, in the order they appear in each record.
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

    # Column-oriented view of the features: label -> list of raw values.
    lenses_dict = {
        each_label: [each[index] for each in lenses]
        for index, each_label in enumerate(lensesLabels)
    }

    lenses_pd = pd.DataFrame(lenses_dict)
    print(lenses_pd)  # raw (string-valued) feature table

    # Replace the string categories of every column with integer codes.
    # fit_transform() refits the encoder each time it is called, so a single
    # LabelEncoder instance can be reused for all columns.
    le = LabelEncoder()
    for col in lenses_pd.columns:
        lenses_pd[col] = le.fit_transform(lenses_pd[col])
    print(lenses_pd)  # integer-encoded feature table

    # Build a decision tree limited to depth 4 and fit it on the encoded rows.
    clf = tree.DecisionTreeClassifier(max_depth=4)
    clf = clf.fit(lenses_pd.values.tolist(), lenses_target)
    print(lenses_target)

    # Predict the class of one sample given as already-encoded feature codes.
    print(clf.predict([[1, 1, 1, 0]]))