zoukankan      html  css  js  c++  java
  • python+sklearn实现决策树(分类树)

    整理今天的代码……
    采用的是150条鸢尾花的数据集fishiris.csv

    # 读入数据,把Name列取出来作为标签(groundtruth)
    import pandas as pd
    data = pd.read_csv('fishiris.csv')
    print(data.head(5))
    X = data.iloc[:, data.columns != 'Name']
    Y = data['Name'] 
    

    df.iloc[rows, columns]取出符合条件的列。查看数据读取是否正确(关于pandas使用最熟练的一条……orz),如果csv文件或者其他数据没有列名需要加上names=[]?

       SepalLength  SepalWidth  PetalLength  PetalWidth    Name
    0          5.1         3.5          1.4         0.2  setosa
    1          4.9         3.0          1.4         0.2  setosa
    2          4.7         3.2          1.3         0.2  setosa
    3          4.6         3.1          1.5         0.2  setosa
    4          5.0         3.6          1.4         0.2  setosa
    

    确认数据无误后就可以分出验证集和测试集,挺方便的!查看一下返回数据的格式和数据集好像是相同的:type(Xtrain):<class .pandas.core.frame.DataFrame'>

    # 分割验证集和训练集
    from sklearn.model_selection import train_test_split
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,Y,test_size=0.2)
    

    数据准备好就可以建模了。注意,因为这个数据集里没有空缺值所以没管,但是sklearn里这个模块好像是不能处理空缺的?要手动写个函数填进去。

    # 导入分类树的模块
    from sklearn.tree import DecisionTreeClassifier
    
    # 需要整理一下序号,也就是更新df.index,前面的可以看到挺乱的,因为是随机取的
    for i in [Xtrain,Xtest,Ytrain,Ytest]: # 这里的意思是i依次为Xtrain,Xtest……并修改它们的index值
        print(i,'before')
        i.index = range(i.shape[0])
        print(i,'after changed')
    
    clf = DecisionTreeClassifier(random_state=3) # 初始化
    clf = clf.fit(Xtrain,Ytrain) # 拟合
    score_ = clf.score(Xtest, Ytest) # 验证集查看得分,这个得分好像就是分类的准确率
    
    # 可以输入数据送到训练好的模型里,输出预测的类
    y_pred = clf.predict(Xtest)
    

    看看:

    # 之前的index
         SepalLength  SepalWidth  PetalLength  PetalWidth
    19           5.1         3.8          1.5         0.3
    67           5.8         2.7          4.1         1.0
    6            4.6         3.4          1.4         0.3
    100          6.3         3.3          6.0         2.5
    39           5.1         3.4          1.5         0.2
    ..           ...         ...          ...         ...
    106          4.9         2.5          4.5         1.7
    25           5.0         3.0          1.6         0.2
    138          6.0         3.0          4.8         1.8
    84           5.4         3.0          4.5         1.5
    94           5.6         2.7          4.2         1.3
    
    [120 rows x 4 columns] before
    
    # 之后的index
    SepalLength  SepalWidth  PetalLength  PetalWidth
    0            5.1         3.8          1.5         0.3
    1            5.8         2.7          4.1         1.0
    2            4.6         3.4          1.4         0.3
    3            6.3         3.3          6.0         2.5
    4            5.1         3.4          1.5         0.2
    ..           ...         ...          ...         ...
    115          4.9         2.5          4.5         1.7
    116          5.0         3.0          1.6         0.2
    117          6.0         3.0          4.8         1.8
    118          5.4         3.0          4.5         1.5
    119          5.6         2.7          4.2         1.3
    
    [120 rows x 4 columns] after changed
    
    输出验证集的预测结果以及和真值的对比:
    ['virginica' 'setosa' 'versicolor' 'setosa' 'setosa' 'versicolor' 'setosa'
     'setosa' 'setosa' 'versicolor' 'virginica' 'versicolor' 'setosa'
     'virginica' 'setosa' 'virginica' 'versicolor' 'versicolor' 'virginica'
     'virginica' 'versicolor' 'versicolor' 'versicolor' 'virginica'
     'virginica' 'versicolor' 'setosa' 'setosa' 'setosa' 'virginica']
    0     True
    1     True
    2     True
    3     True
    4     True
    5     True
    6     True
    7     True
    8     True
    9     True
    10    True
    11    True
    12    True
    13    True
    14    True
    15    True
    16    True
    17    True
    18    True
    19    True
    20    True
    21    True
    22    True
    23    True
    24    True
    25    True
    26    True
    27    True
    28    True
    29    True
    Name: Name, dtype: bool
    

    更高级的建模方法:利用GridSearchCV这个模块!

    # 预测结果不准确,可以使用网格法优化,这里设定了模型训练的多个参数,利用sklearn里的模块可以自己测试并选择结果最好的一个模型?我还不是很懂
    parameters = {'splitter':('best','random')
                    ,'criterion':("gini","entropy")
                    ,"max_depth":[*range(1,10)]
                    ,'min_samples_leaf':[*range(1,50,5)]
                    ,'min_impurity_decrease':[*np.linspace(0,0.5,20)]
    }
    from sklearn.model_selection import GridSearchCV
    GS = GridSearchCV(clf, parameters, cv=10)
    GS.fit(Xtrain,Ytrain)
    print(GS.best_params_)
    print(GS.best_score_)
    {'criterion': 'gini', 'max_depth': 5, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'splitter': 'random'}
    0.9703296703296704
    
    score_ = clf.score(Xtest, Ytest)
    print(score_,'score') # 1.0 score
    

    明天想把图画出来嗷嗷,然后再试试回归树!

  • 相关阅读:
    Spring3+hibernate4+struts2整合的 过程中发生如下错误
    使用HQL语句的按照参数名字查询数据库信息的时候 “=:”和参数之间不能存在空格,否则会报错
    org.hibernate.service.classloading.spi.ClassLoadingException: Specified JDBC Driver com.mysql.jdbc.Driver class not found
    Java多线程编程:
    数据库连接池的工作原理
    Oracle数据库表的备份和数据表的删除操作
    数据库连接池
    Mysql登录异常的一个问题:
    2019年终总结
    设计模式入门-简单工厂模式
  • 原文地址:https://www.cnblogs.com/sweetsmartrange/p/13352474.html
Copyright © 2011-2022 走看看