zoukankan      html  css  js  c++  java
  • early stopping早停_pytorch学习

    转自:https://blog.csdn.net/weixin_40446557/article/details/103387629

    1.介绍

    结合交叉验证法,可以防止模型过早拟合。在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题。

    注:需要将数据集分为训练集和验证集。

    早停法主要是训练时间泛化错误之间的权衡。 

            //...........
            early_stopping(valid_loss, model)
            
            if early_stopping.early_stop:
                print("Early stopping")
                break
            //...........

    链接中提供的代码早停是针对验证集上的损失。

    2.解释

    知乎 https://www.zhihu.com/question/59201590/answer/167392763

    经过每个神经元,在给定激活函数的情况下,它的激活能力是和参数有关系的。

    网络一开始训练的时候会赋值权重为较小值,它的拟合能力弱,基本为线性,随着网络的训练,权重会增大,那么就早停,使得网络的参数不那么复杂,降低它的过拟合程度,降低拟合训练集的能力。

    3.skorch中的earlystopping

    skorch.callbacks.EarlyStopping(patience=args.earlystop)
    
    class EarlyStopping(Callback):
            def __init__(
                self,
                monitor='valid_loss', #默认监控的是valid集上的损失
                patience=5,
                threshold=1e-4,
                threshold_mode='rel',
                lower_is_better=True,
                sink=print,
        ):

    4.valid_loss计算

    https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/ 的例子:

        with torch.no_grad():
            for i, data in prog_bar:
                counter += 1
                data, target = data[0].to(device), data[1].to(device)
                total += target.size(0)
                outputs = model(data)
                loss = criterion(outputs, target)
                
                val_running_loss += loss.item()
                _, preds = torch.max(outputs.data, 1)
                val_running_correct += (preds == target).sum().item()
            
            val_loss = val_running_loss / counter
            val_accuracy = 100. * val_running_correct / total
            return val_loss, val_accuracy

    其中val_loss 是先计算所有batch的总和,然后在batch数目上取均值。

    https://github.com/Cai-Yichao/torch_backbones/blob/e3d4850603a795cee0710bba8f83db74f5a70d68/train.py#L96 例子中:

        with torch.no_grad():
            for index, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss_curr = criterion(outputs, targets)
    
                loss += loss_curr.item()
         
        eval_loss = loss/(index+1)

     其中val_loss 是计算在每个batch上的损失。

    https://blog.csdn.net/weixin_40446557/article/details/103387629 给的例子:

            for data, target in valid_loader:
                # forward pass: compute predicted outputs by passing inputs to the model
                output = model(data)
                # calculate the loss
                loss = criterion(output, target)
                # record validation loss
                valid_losses.append(loss.item())
    
           valid_loss = np.average(valid_losses)
           early_stopping(valid_loss, model)

     其中val_loss 是计算在每个batch上的损失。

    以上的三个例子都是valid_loss在batch水平的损失。

  • 相关阅读:
    发个小程序希望有人需要(操作摄像头)
    (转)Qt中translate、tr关系 与中文问题
    VS2008代码自动对齐
    (转)Qt国际化(源码含中文时)的点滴分析
    (转)Bibtex使用方法
    (转)new,operate new和placement new
    (转)C++中的虚函数表
    (转)QString 与中文问题
    (转)static_cast, dynamic_cast, const_cast探讨
    试试
  • 原文地址:https://www.cnblogs.com/BlueBlueSea/p/14563880.html
Copyright © 2011-2022 走看看