转自:https://blog.csdn.net/weixin_40446557/article/details/103387629
1.介绍
结合交叉验证法,可以防止模型过早拟合。在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题。
注:需要将数据集分为训练集和验证集。
早停法主要是训练时间和泛化错误之间的权衡。
//........... early_stopping(valid_loss, model) if early_stopping.early_stop: print("Early stopping") break //...........
链接中提供的代码早停是针对验证集上的损失。
2.解释
知乎 https://www.zhihu.com/question/59201590/answer/167392763
经过每个神经元,在给定激活函数的情况下,它的激活能力是和参数有关系的。
网络一开始训练的时候会赋值权重为较小值,它的拟合能力弱,基本为线性,随着网络的训练,权重会增大,那么就早停,使得网络的参数不那么复杂,降低它的过拟合程度,降低拟合训练集的能力。
3.skorch中的earlystopping
skorch.callbacks.EarlyStopping(patience=args.earlystop) class EarlyStopping(Callback): def __init__( self, monitor='valid_loss', #默认监控的是valid集上的损失 patience=5, threshold=1e-4, threshold_mode='rel', lower_is_better=True, sink=print, ):
4.valid_loss计算
https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/ 的例子:
with torch.no_grad(): for i, data in prog_bar: counter += 1 data, target = data[0].to(device), data[1].to(device) total += target.size(0) outputs = model(data) loss = criterion(outputs, target) val_running_loss += loss.item() _, preds = torch.max(outputs.data, 1) val_running_correct += (preds == target).sum().item() val_loss = val_running_loss / counter val_accuracy = 100. * val_running_correct / total return val_loss, val_accuracy
其中val_loss 是先计算所有batch的总和,然后在batch数目上取均值。
https://github.com/Cai-Yichao/torch_backbones/blob/e3d4850603a795cee0710bba8f83db74f5a70d68/train.py#L96 例子中:
with torch.no_grad(): for index, (inputs, targets) in enumerate(test_loader): inputs, targets = inputs.to(device), targets.to(device) outputs = net(inputs) loss_curr = criterion(outputs, targets) loss += loss_curr.item() eval_loss = loss/(index+1)
其中val_loss 是计算在每个batch上的损失。
https://blog.csdn.net/weixin_40446557/article/details/103387629 给的例子:
for data, target in valid_loader: # forward pass: compute predicted outputs by passing inputs to the model output = model(data) # calculate the loss loss = criterion(output, target) # record validation loss valid_losses.append(loss.item()) valid_loss = np.average(valid_losses) early_stopping(valid_loss, model)
其中val_loss 是计算在每个batch上的损失。
以上的三个例子都是valid_loss在batch水平的损失。