zoukankan      html  css  js  c++  java
  • 【笔记】可视化模型误差之学习曲线

    学习曲线

    过拟合和欠拟合以及为什么要对分为训练数据集和测试数据集中绘制了模型复杂度曲线,那么如果还想用别的方法可视化过拟合和欠拟合的关系的话,可以使用学习曲线

    什么是学习曲线?

    学习曲线描述的就是随着训练样本的逐渐增多,算法训练出的模型的表现能力的变化

    具体实现体现一下

    (在notebook中)

    使用熟悉的配方熟悉的味道来绘制图形

      np.random.seed(666)
      x = np.random.uniform(-3.0,3.0,size=100)
      X = x.reshape(-1,1)
      y = 0.5 * x**2 + x + 2 + np.random.normal(0,1,size=100)
      plt.scatter(x,y)
    

    图像如下

    先对数据集进行分割,随机的种子设为10

      from sklearn.model_selection import train_test_split
      X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=10)
    

    此时在默认情况下,X_train中含有75个元素(默认取的25%的数据)

    那么学习曲线的绘制其实就是对于这75个训练数据,每一次都多一些训练数据来训练模型,以此来观察得到的模型在训练数据集和测试数据集上的表现是怎样的

    使用线性

    使用一个循环,从只给一个数据训练开始,到最后给75个数据训练,每一次都要训练一个模型,在循环中每一次都实例化一个lin_reg,然后对lin_reg进行fit关于训练的数据x和y的前i个元素,相应的训练完以后就开始预测,对预测结果来说,分成y_train_predict和y_test_predict,每一次预测的结果在相应的数据集上的对应的均方误差,运行程序以后,对性能变化的曲线进行绘制,x就是1到75的每个值,相应的y值就使用np.sqrt(对应的误差)来表示,对训练和测试都绘制出来

      from sklearn.metrics import mean_squared_error
      from sklearn.linear_model import LinearRegression
    
      train_score = []
      test_score = []
      for i in range(1,76):
          lin_reg = LinearRegression()
          lin_reg.fit(X_train[:i],y_train[:i])
    
          y_train_predict = lin_reg.predict(X_train[:i])
          train_score.append(mean_squared_error(y_train[:i],y_train_predict))
    
          y_test_predict = lin_reg.predict(X_test)
          test_score.append(mean_squared_error(y_test,y_test_predict))
    
      plt.plot([i for i in range(1,76)],np.sqrt(train_score),label="train")
      plt.plot([i for i in range(1,76)],np.sqrt(test_score),label="test")
      plt.legend()
    

    结果如下,这就是基于线性回归来说的样本数据的学习曲线

    大体趋势上,很明显的可以看出来在训练数据集上的误差是逐渐升高的,开始快,后面累计的小,相应的比较稳定,对于训练数据集来说,开始样本比较少的时候,误差非常的大,当样本到一定程度以后,误差就会减小,最后相对稳定,整体上学习曲线就是这样的趋势

    将上述代码提炼成一个函数plot_learning_curve,只需要传入机器学习算法和数据集即可得到对应的算法的学习曲线,不同之处在于对绘制图像的时候将范围进行了限定,y是0到4,x为0到X的训练数据集的数量

      def plot_learning_curve(algo,X_train,X_test,y_train,y_test):
          train_score = []
          test_score = []
          for i in range(1,len(X_train)+1):
              algo.fit(X_train[:i],y_train[:i])
    
              y_train_predict = algo.predict(X_train[:i])
              train_score.append(mean_squared_error(y_train[:i],y_train_predict))
    
              y_test_predict = algo.predict(X_test)
              test_score.append(mean_squared_error(y_test,y_test_predict))
        
          plt.plot([i for i in range(1,len(X_train)+1)],
                                np.sqrt(train_score),label="train")
          plt.plot([i for i in range(1,len(X_train)+1)],
                                np.sqrt(test_score),label="test")
          plt.legend()
          plt.axis([0,len(X_train)+1,0,4])
    

    调用函数以后,即可得到

      plot_learning_curve(LinearRegression(),X_train,X_test,y_train,y_test)
    

    结果如下(欠拟合)

    使用多项式回归

    多项式回归如下:

      from sklearn.pipeline import Pipeline
      from sklearn.preprocessing import StandardScaler
      from sklearn.preprocessing import PolynomialFeatures
    
      def PolynomialRegression(degree):
          return Pipeline([
              ("poly",PolynomialFeatures(degree=degree)),
              ("std_scaler",StandardScaler()),
              ("lin_reg",LinearRegression())
          ])
    

    传入degree为2,然后调用函数

      poly2_reg = PolynomialRegression(degree=2)
      plot_learning_curve(poly2_reg,X_train,X_test,y_train,y_test)
    

    结果如下(正合适)

    可以发现趋势是大致和线性一致的,不过线性最后的稳定位置比多项式的最后的稳定位置要大

    若传入degree为20

      poly20_reg = PolynomialRegression(degree=20)
      plot_learning_curve(poly20_reg,X_train,X_test,y_train,y_test)
    

    结果如下(过拟合)

    大致趋势还是一样,但是其最后稳定的时候训练数据和测试数据的差距还是比较大的,可以发现是在训练数据集上已经拟合的很好了,但是在测试数据集上,误差还是很大的,这种一般就是过拟合的情况

    对于欠拟合来说,比最佳的情况的稳定的位置要高一些,说明无论对于哪个数据集,其误差都是比较大的

    对于过拟合来说,训练数据集的误差是更低的,但是问题是测试数据集的误差比较大,离训练数据集比较远,这就说明这个模型的泛化能力比较弱,对于新的数据,误差是比较大的

  • 相关阅读:
    要开学了,暂时停更
    day13 IP包头分析 | 路由器原理 1
    day12 数据链路层 | 交换机基本命令
    day11 OSI与TCP-IP 5层协议 | 物理层相关知识
    day10 扫描与爆破
    day 09 简单渗透测试
    day07 PKI
    day07 域
    day06 WEB服务器 | FTP服务器
    day05 DHCP部署与安全 | DNS部署与安全
  • 原文地址:https://www.cnblogs.com/jokingremarks/p/14309319.html
Copyright © 2011-2022 走看看