zoukankan      html  css  js  c++  java
  • 【笔记】CART与决策树中的超参数

    CART与决策树中的超参数

    先前的决策树其实应该称为CART

    CART的英文是Classification and regression tree,全称为分类与回归树,其是在给定输入随机变量X条件下输出随机变量Y的条件概率分布的学习方法,就是假设决策树是二叉树,内部结点特征的取值为“是”和“否”,左分支是取值为“是”的分支,右分支是取值为“否”的分支,其可以解决分类问题,又可以解决回归问题,特点就是根据某一个维度d和某一个阈值v进行二分

    在sklearn中的决策树都是CART的方式实现的

    回顾前面,不难发现,对于之前的决策树的划分模拟,平均而言,预测的复杂度是O(logm),其中m为样本个数,这样创建决策树的训练的过程的复杂度是O(nmlogm),是很高的,n为维度数,还有一个重要的问题就是很容易产生过拟合

    那么因为种种原因,在创建决策树的时候需要进行剪枝,即降低复杂度,解决过拟合的操作

    剪枝的操作有很多种,实际上就是对各种参数进行平衡

    通过具体操作来体现一下

    (在notebook中)

    加载好需要的包,使用make_moons生成虚拟数据,将数据情况绘制出来

      import numpy as np
      import matplotlib.pyplot as plt
      from sklearn import datasets
    
      X,y = datasets.make_moons(noise=0.25,random_state=666)
    
      plt.scatter(X[y==0,0],X[y==0,1])
      plt.scatter(X[y==1,0],X[y==1,1])
    

    图像如下

    然后使用DecisionTreeClassifier这个类,不进行任何限定操作,进行实例化操作以后训练数据

      from sklearn.tree import DecisionTreeClassifier
    
      dt_clf = DecisionTreeClassifier()
      dt_clf.fit(X,y)
    

    绘制函数,绘制图像

    from matplotlib.colors import ListedColormap
    def plot_decision_boundary(model, axis):
    
        x0,x1 = np.meshgrid(  
            np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
            np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1)
        )
        X_new = np.c_[x0.ravel(),x1.ravel()]
        
        y_predict = model.predict(X_new)
        zz = y_predict.reshape(x0.shape)
        
        custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])
    
        plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)
    
      plot_decision_boundary(dt_clf,axis=[-1.5,2.5,-1.0,1.5])
      plt.scatter(X[y==0,0],X[y==0,1])
      plt.scatter(X[y==1,0],X[y==1,1])
    

    图像如下(可以发现决策边界很不规则,很明显产生了过拟合)

    设置参数max_depth为2,限制最大深度为2,然后训练数据并绘制模型

      dt_clf2 = DecisionTreeClassifier(max_depth=2)
      dt_clf2.fit(X,y)
    
      plot_decision_boundary(dt_clf2,axis=[-1.5,2.5,-1.0,1.5])
      plt.scatter(X[y==0,0],X[y==0,1])
      plt.scatter(X[y==1,0],X[y==1,1])
    

    图像如下(可以发现过拟合已经不明显了,很清晰的表示出了边界)

    还可以设置参数min_samples_split为10,限制一个节点的样本数最小为10,然后训练数据并绘制模型

      dt_clf3 = DecisionTreeClassifier(min_samples_split=10)
      dt_clf3.fit(X,y)
    
      plot_decision_boundary(dt_clf3,axis=[-1.5,2.5,-1.0,1.5])
      plt.scatter(X[y==0,0],X[y==0,1])
      plt.scatter(X[y==1,0],X[y==1,1])
    

    图像如下(过拟合程度低)

    还可以设置参数min_samples_leaf为6,限制一个叶子节点样本数至少为6,然后训练数据并绘制模型

      dt_clf4 = DecisionTreeClassifier(min_samples_leaf=6)
      dt_clf4.fit(X,y)
    
      plot_decision_boundary(dt_clf4,axis=[-1.5,2.5,-1.0,1.5])
      plt.scatter(X[y==0,0],X[y==0,1])
      plt.scatter(X[y==1,0],X[y==1,1])
    

    图像如下

    还可以设置参数max_leaf_nodes为4,限制叶子节点最多为4,然后训练数据并绘制模型

      dt_clf5 = DecisionTreeClassifier(max_leaf_nodes=4)
      dt_clf5.fit(X,y)
    
      plot_decision_boundary(dt_clf5,axis=[-1.5,2.5,-1.0,1.5])
      plt.scatter(X[y==0,0],X[y==0,1])
      plt.scatter(X[y==1,0],X[y==1,1])
    

    图像如下

    可以发现,对参数进行适当的修改可以很好的解决过拟合的问题,但是需要注意不要调节到欠拟合,那么寻找到合适的参数就可以使用网格搜索的方式

  • 相关阅读:
    jQuery ajax中支持的数据类型
    行内元素与块级元素
    本地连接无法加载远程访问连接管理器服务,错误711
    SQL Server 两种判断表名是否存在且删除的方式
    SQL Server 2008 修改表名
    MySql5.1在Win7下的安装与重装问题的解决
    JavaScript关闭浏览器
    SQL Server 添加一条数据获取自动增长列的几种方法
    获取当前程序运行目录
    字符串的判断与替换
  • 原文地址:https://www.cnblogs.com/jokingremarks/p/14341963.html
Copyright © 2011-2022 走看看