zoukankan      html  css  js  c++  java
  • sklearn CART决策树分类

    sklearn CART决策树分类

    决策树是一种常用的机器学习方法,可以用于分类和回归。同时,决策树的训练结果非常容易理解,而且对于数据预处理的要求也不是很高。

    理论部分

    比较经典的决策树是ID3、C4.5和CART,分别分析信息增益、增益率、基尼指数,总体思想是不断降低信息的不确定性,最后达到分类的目的。

    这里介绍的CART(Classification And Regression Tree)决策树选用基尼指数(Gini Index)来依次选择划分属性

    [Gini(D)=sum_{k=1}^{n} sum_{k_1 ot=k_2}p_{k_1}p_{k_2}=1-sum_{j=1}^{n}p_j^2 ]

    数据集的基尼指数越大,表示该数据集的信息量越大,可能性越多,越混乱;基尼指数越小,表示数据集越纯净。

    属性a的基尼指数定义为

    [Gini\_index(D,a)=sum_{v=1}^Vfrac{|D^v|}{D}Gini(D^v) ]

    表示确定属性(a)等于某个(v)后,数据集基尼指数的加权平均。

    每一轮求出各个属性的基尼指数,然后每次取最大属性的进行划分,这样总体信息的不确定性就会降低得最快。

    决策树生成前后,为了防止过拟合,还要使用剪枝(pruning)操作,这里不再展开。

    sklearn代码实现

    #coding=utf-8
    
    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split
    from sklearn import datasets
    from sklearn import tree
    import numpy as np
    from sklearn.externals.six import StringIO
    import pydot
    
    def main():
        iris = datasets.load_iris() #典型分类数据模型
        #这里我们数据统一用pandas处理
        data = pd.DataFrame(iris.data, columns=iris.feature_names)
        data['class'] = iris.target
        
        #这里只取两类
        data = data[data['class']!=2]
        #为了可视化方便,这里取两个属性为例
        X = data[['sepal length (cm)','sepal width (cm)']]
        Y = data[['class']]
        #划分数据集
        X_train, X_test, Y_train, Y_test =train_test_split(X, Y)
        #创建决策树模型对象,默认为CART
        dt = tree.DecisionTreeClassifier()
        dt.fit(X_train, Y_train)
        
        #显示训练结果
        print dt.score(X_test, Y_test) #score是指分类的正确率
       
        #作图
        h = 0.02
        x_min, x_max = X.iloc[:, 0].min() - 1, X.iloc[:, 0].max() + 1
        y_min, y_max = X.iloc[:, 1].min() - 1, X.iloc[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                             np.arange(y_min, y_max, h))
        Z = dt.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)
        plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
    
        #做出原来的散点图
        class1_x = X.loc[Y['class']==0,'sepal length (cm)']
        class1_y = X.loc[Y['class']==0,'sepal width (cm)']
        l1 = plt.scatter(class1_x,class1_y,color='b',label=iris.target_names[0])
        class1_x = X.loc[Y['class']==1,'sepal length (cm)']
        class1_y = X.loc[Y['class']==1,'sepal width (cm)']
        l2 = plt.scatter(class1_x,class1_y,color='r',label=iris.target_names[1])
        plt.legend(handles = [l1, l2], loc = 'best')
        
        plt.grid(True)
        plt.show()
        #导出决策树的图片,需要配置graphviz,并且添加到环境变量
        dot_data = StringIO()
        tree.export_graphviz(dt, out_file=dot_data,feature_names=X.columns,  
                             class_names=['healthy','infected'],
                             filled=True, rounded=True,  
                             special_characters=True)
        graph = pydot.graph_from_dot_data(dot_data.getvalue())[0]
        graph.write_png("Iris.png")
        
    
    if __name__ == '__main__':
        main()
    

    测试结果

    0.92
    

    matlibplot显示

    blog

    Iris.png

    Iris

  • 相关阅读:
    关于js计算非等宽字体宽度的方法
    [NodeJs系列]聊一聊BOM
    Vue.js路由管理器 Vue Router
    vue 实践技巧合集
    微任务、宏任务与Event-Loop
    事件循环(EventLoop)的学习总结
    Cookie、Session和LocalStorage
    MySQL 树形结构 根据指定节点 获取其所在全路径节点序列
    MySQL 树形结构 根据指定节点 获取其所有父节点序列
    MySQL 创建函数报错 This function has none of DETERMINISTIC, NO SQL, or READS SQL DATA in its declaration and binary logging is enabled (you *might* want to use the less safe log_bin_trust_function_creators
  • 原文地址:https://www.cnblogs.com/fanghao/p/7517671.html
Copyright © 2011-2022 走看看