zoukankan      html  css  js  c++  java
  • sklearn CART决策树分类

    sklearn CART决策树分类

    决策树是一种常用的机器学习方法,可以用于分类和回归。同时,决策树的训练结果非常容易理解,而且对于数据预处理的要求也不是很高。

    理论部分

    比较经典的决策树是ID3、C4.5和CART,分别分析信息增益、增益率、基尼指数,总体思想是不断降低信息的不确定性,最后达到分类的目的。

    这里介绍的CART(Classification And Regression Tree)决策树选用基尼指数(Gini Index)来依次选择划分属性

    [Gini(D)=sum_{k=1}^{n} sum_{k_1 ot=k_2}p_{k_1}p_{k_2}=1-sum_{j=1}^{n}p_j^2 ]

    数据集的基尼指数越大,表示该数据集的信息量越大,可能性越多,越混乱;基尼指数越小,表示数据集越纯净。

    属性a的基尼指数定义为

    [Gini\_index(D,a)=sum_{v=1}^Vfrac{|D^v|}{D}Gini(D^v) ]

    表示确定属性(a)等于某个(v)后,数据集基尼指数的加权平均。

    每一轮求出各个属性的基尼指数,然后每次取最大属性的进行划分,这样总体信息的不确定性就会降低得最快。

    决策树生成前后,为了防止过拟合,还要使用剪枝(pruning)操作,这里不再展开。

    sklearn代码实现

    #coding=utf-8
    
    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split
    from sklearn import datasets
    from sklearn import tree
    import numpy as np
    from sklearn.externals.six import StringIO
    import pydot
    
    def main():
        iris = datasets.load_iris() #典型分类数据模型
        #这里我们数据统一用pandas处理
        data = pd.DataFrame(iris.data, columns=iris.feature_names)
        data['class'] = iris.target
        
        #这里只取两类
        data = data[data['class']!=2]
        #为了可视化方便,这里取两个属性为例
        X = data[['sepal length (cm)','sepal width (cm)']]
        Y = data[['class']]
        #划分数据集
        X_train, X_test, Y_train, Y_test =train_test_split(X, Y)
        #创建决策树模型对象,默认为CART
        dt = tree.DecisionTreeClassifier()
        dt.fit(X_train, Y_train)
        
        #显示训练结果
        print dt.score(X_test, Y_test) #score是指分类的正确率
       
        #作图
        h = 0.02
        x_min, x_max = X.iloc[:, 0].min() - 1, X.iloc[:, 0].max() + 1
        y_min, y_max = X.iloc[:, 1].min() - 1, X.iloc[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                             np.arange(y_min, y_max, h))
        Z = dt.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)
        plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
    
        #做出原来的散点图
        class1_x = X.loc[Y['class']==0,'sepal length (cm)']
        class1_y = X.loc[Y['class']==0,'sepal width (cm)']
        l1 = plt.scatter(class1_x,class1_y,color='b',label=iris.target_names[0])
        class1_x = X.loc[Y['class']==1,'sepal length (cm)']
        class1_y = X.loc[Y['class']==1,'sepal width (cm)']
        l2 = plt.scatter(class1_x,class1_y,color='r',label=iris.target_names[1])
        plt.legend(handles = [l1, l2], loc = 'best')
        
        plt.grid(True)
        plt.show()
        #导出决策树的图片,需要配置graphviz,并且添加到环境变量
        dot_data = StringIO()
        tree.export_graphviz(dt, out_file=dot_data,feature_names=X.columns,  
                             class_names=['healthy','infected'],
                             filled=True, rounded=True,  
                             special_characters=True)
        graph = pydot.graph_from_dot_data(dot_data.getvalue())[0]
        graph.write_png("Iris.png")
        
    
    if __name__ == '__main__':
        main()
    

    测试结果

    0.92
    

    matlibplot显示

    blog

    Iris.png

    Iris

  • 相关阅读:
    12.SolrCloud原理
    11.SolrCloud集群环境搭建
    10.Solr4.10.3数据导入(DIH全量增量同步Mysql数据)
    9.Solr4.10.3数据导入(post.jar方式和curl方式)
    Java程序设计之最大公约数和最小公倍数
    Java程序设计之正则表达式
    Java程序设计之整数分解
    Java程序设计之裴波拉切那数列(兔子一年的数量)
    Java并发编程实例(synchronized)
    Java程序设计之合租房synchronized(二)
  • 原文地址:https://www.cnblogs.com/fanghao/p/7517671.html
Copyright © 2011-2022 走看看