zoukankan      html  css  js  c++  java
  • ML | Early Stopping是什么

    具体EarlyStopping的使用请参考官方文档源代码

    EarlyStopping是Callbacks的一种,callbacks用于指定在每个epoch开始和结束的时候进行哪种特定操作。Callbacks中有一些设置好的接口,可以直接使用,如’acc’, 'val_acc’, ’loss’ 和 ’val_loss’等等。
    EarlyStopping则是用于提前停止训练的callbacks。具体地,可以达到当训练集上的loss不在减小(即减小的程度小于某个阈值)的时候停止继续训练。
     

    为什么要用

    为了获得性能良好的神经网络,网络定型过程中需要进行许多关于所用设置(超参数)的决策。超参数之一是定型周期(epoch)的数量:亦即应当完整遍历数据集多少次(一次为一个epoch)?如果epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。

    早停法旨在解决epoch数量需要手动设置的问题。它也可以被视为一种能够避免网络发生过拟合的正则化方法(与L1/L2权重衰减和丢弃法类似)。

    根本原因就是因为继续训练会导致测试集上的准确率下降。
    那继续训练导致测试准确率下降的原因猜测可能是1. 过拟合 2. 学习率过大导致不收敛 3. 使用正则项的时候,Loss的减少可能不是因为准确率增加导致的,而是因为权重大小的降低。

    原理

    • 将数据分为训练集和验证集
    • 每个epoch结束后(或每N个epoch后): 在验证集上获取测试结果,随着epoch的增加,如果在验证集上发现测试误差上升,则停止训练;
    • 将停止之后的权重作为网络的最终参数。

    这种做法很符合直观感受,因为精度都不再提高了,在继续训练也是无益的,只会提高训练的时间。那么该做法的一个重点便是怎样才认为验证集精度不再提高了呢?并不是说验证集精度一降下来便认为不再提高了,因为可能经过这个Epoch后,精度降低了,但是随后的Epoch又让精度又上去了,所以不能根据一两次的连续降低就判断不再提高。一般的做法是,在训练的过程中,记录到目前为止最好的验证集精度,当连续10次Epoch(或者更多次)没达到最佳精度时,则可以认为精度不再提高了。

    直观理解

    最优模型是在垂直虚线的时间点保存下来的模型,即处理测试集时准确率最高的模型。

    为什么能减小过拟合

    当还未在神经网络运行太多迭代过程的时候,w参数接近于0,因为随机初始化w值的时候,它的值是较小的随机值。当你开始迭代过程,w的值会变得越来越大。到后面时,w的值已经变得十分大了。所以early stopping要做的就是在中间点停止迭代过程。我们将会得到一个中等大小的w参数,会得到与L2正则化相似的结果,选择了w参数较小的神经网络。

    Early Stopping的优缺点

    优点:只运行一次梯度下降,我们就可以找出w的较小值,中间值和较大值。而无需尝试L2正则化超级参数lambda的很多值。

    缺点:不能独立地处理以上两个问题,使得要考虑的东西变得复杂。举例如下:

    没有采取不同的方式来解决优化损失函数和降低方差这两个问题,而是用一种方法同时解决两个问题 ,结果就是要考虑的东西变得更复杂。之所以不能独立地处理,因为如果你停止了优化代价函数,你可能会发现代价函数的值不够小,同时你又不希望过拟合。

    EarlyStopping的使用与技巧

    一般是在model.fit函数中调用callbacks,fit函数中有一个参数为callbacks。注意这里需要输入的是list类型的数据,所以通常情况只用EarlyStopping的话也要是[EarlyStopping()]

    EarlyStopping的参数:

    • monitor: 监控的数据接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情况下如果有验证集,就用’val_acc’或者’val_loss’。但是因为笔者用的是5折交叉验证,没有单设验证集,所以只能用’acc’了。
    • min_delta:增大或减小的阈值,只有大于这个部分才算作improvement。这个值的大小取决于monitor,也反映了你的容忍程度。例如笔者的monitor是’acc’,同时其变化范围在70%-90%之间,所以对于小于0.01%的变化不关心。加上观察到训练过程中存在抖动的情况(即先下降后上升),所以适当增大容忍程度,最终设为0.003%。
    • patience:能够容忍多少个epoch内都没有improvement。这个设置其实是在抖动和真正的准确率下降之间做tradeoff。如果patience设的大,那么最终得到的准确率要略低于模型可以达到的最高准确率。如果patience设的小,那么模型很可能在前期抖动,还在全图搜索的阶段就停止了,准确率一般很差。patience的大小和learning rate直接相关。在learning rate设定的情况下,前期先训练几次观察抖动的epoch number,比其稍大些设置patience。在learning rate变化的情况下,建议要略小于最大的抖动epoch number。笔者在引入EarlyStopping之前就已经得到可以接受的结果了,EarlyStopping算是锦上添花,所以patience设的比较高,设为抖动epoch number的最大值。
    • mode: 就’auto’, ‘min’, ‘,max’三个可能。如果知道是要上升还是下降,建议设置一下。笔者的monitor是’acc’,所以mode=’max’。

    min_delta和patience都和“避免模型停止在抖动过程中”有关系,所以调节的时候需要互相协调。通常情况下,min_delta降低,那么patience可以适当减少;min_delta增加,那么patience需要适当延长;反之亦然。

    class RocAucMetricCallback(keras.callbacks.Callback):
        def __init__(self, predict_batch_size=1024):
            super(RocAucMetricCallback, self).__init__()
            self.predict_batch_size = predict_batch_size
     
        def on_batch_begin(self, batch, logs={}):
            pass
     
        def on_batch_end(self, batch, logs={}):
            pass
     
        def on_train_begin(self, logs={}):
            if not ('val_roc_auc' in self.params['metrics']):
                self.params['metrics'].append('val_roc_auc')
     
        def on_train_end(self, logs={}):
            pass
     
        def on_epoch_begin(self, epoch, logs={}):
            pass
     
        def on_epoch_end(self, epoch, logs={}):
            logs['roc_auc'] = float('-inf')
            if (self.validation_data):
                logs['roc_auc'] = roc_auc_score(self.validation_data[1],
                                                self.model.predict(self.validation_data[0],
                                                                   batch_size=self.predict_batch_size))
                print('ROC_AUC - epoch:%d - score:%.6f' % (epoch + 1, logs['roc_auc']))

        my_callbacks = [
            RocAucMetricCallback(),  # include it before EarlyStopping!
            EarlyStopping(monitor='roc_auc', patience=20, verbose=2, mode='max')
        ]
     
        mlp.fit(X_train_pre, y_train_pre,
                batch_size=512,
                epochs=500,
                class_weight="auto",
                callbacks=my_callbacks,
                validation_data=(X_train_pre_val, y_train_pre_val))

     

    扩充

    如果不用early stopping降低过拟合,另一种方法就是L2正则化,但需尝试L2正则化超级参数λ的很多值,个人更倾向于使用L2正则化,尝试许多不同的λ值。

    转载 https://blog.csdn.net/zwqjoy/article/details/86677030

  • 相关阅读:
    优化SQL查询:如何写出高性能SQL语句
    提高SQL执行效率的16种方法
    Spring Ioc DI 原理
    java内存泄漏
    转:js闭包
    LeetCode Best Time to Buy and Sell Stock III
    LeetCode Best Time to Buy and Sell Stock with Cooldown
    LeetCode Length of Longest Fibonacci Subsequence
    LeetCode Divisor Game
    LeetCode Sum of Even Numbers After Queries
  • 原文地址:https://www.cnblogs.com/gaodi2345/p/14705375.html
Copyright © 2011-2022 走看看