zoukankan      html  css  js  c++  java
  • Pytorch中的model.train()与model.eval()原理和实验分析

    Pytorch中的model.train()与model.eval()

     最近在跑实验代码, 发现对于Pytorch中的model.train()与model.eval()两种模式的理解只是停留在理论知识的层面,缺少了实操的经验。下面博主将从理论层面与实验经验这两个方面总结model.train()与model.eval()的区别和坑点。

    0. 理论区别

     首先需要明确的是这两个模式会影响DropoutBatchNormal这两个Module的行为。

    0.1 Dropout

     在train模式下,Dropout会根据概率p随机将部分输出变为0,即让一些神经元失活。

     在eval模式下, Dropout将不再会将部分输出变为0,即相当于从模型中移除了该Module。

    0.2 BatchNorm

     首先,有关BN的理论学习请移步我另一篇博客带你一文读懂Batch Normalization。这里我只是讨论train与eval模式下BN的行为差异。

     首先需要明确BN的行为由training属性(这里就是通过model.train()设置)和track_running_stats属性控制。

     在BN中track_running_stats属性默认为True,在train模式下,forward的时候统计running_mean, running_var并将其作为 μ , σ mu, sigma μ,σ,其统计公式如下图所示,在eval模式下,利用前面统计的均值和方差作为 μ , σ mu, sigma μ,σ用于inference。其中running_mean, running_var的更新规则如下,可以简单地理解为计算了多个batch的平均值,要想了解指数平均具体原理可以一步博主另一篇博客一文带你入门深度学习优化算法
    在这里插入图片描述
     控制逻辑如下:
    在这里插入图片描述
     下面根据train和track_running_stats的两种状态两两组合为四类,分类讨论:

    敲重点!!!:

    1. training=True, track_running_stats=True:即训练模式下跟踪运行状态, 此时前向传播使用的 σ , μ sigma, mu σ,μ为running_mean、running_var。
    2. training=True, track_running_stats=False:即训练模式下不跟踪运行状态,此时前向传播使用的 σ , μ sigma, mu σ,μ为None,也就是当前batch的均值和方差
    3. training=False, track_running_stats=True:即评估模式下跟踪运行状态,此时前向传播使用的 σ , μ sigma, mu σ,μ为running_mean、running_var。
    4. training=False, track_running_stats=False:即评估模式下不跟踪运行状态,此时前向传播使用的 σ , μ sigma, mu σ,μ为running_mean、running_var。

     所以使用当前batch的统计量的情况只有一个,就是训练模式下不跟踪运行状态。

    1. 实验经验1

    1.1 实验配置

     要从实验经验上面讲两种模式,就需要从一个具体的实验开始。首先说一下我的实验配置:

    • 模型中全部使用默认参数创建BN,即此时momentum=0.1,track_running_stats=True,也就是大概跟踪10个batch的平均值,没有Dropout
    • batch size为12
    • 数据集分为train split和test split,训练在train split上,测试精度在test split之上得到
    • 使用了在ImageNet之上pretrain的模型,即此时BN已经有了running_mean和running_var了
    • 使用了ImageNet的mean和std对数据进行规范化处理
    • 数据集为CVUSA

    1.2 实验现象

     最后的实验结果发现模型在eval模式之下训练比在train模式之下训练的测试精度更好

    1.3 原因分析

     通过上述理论和实验配置我们可以知道,实验现象就是在说,训练中,模型使用ImageNet中的running_mean和running_var作为 μ , σ mu, sigma μ,σ的效果会好于在训练中不断利用数据集CVUSA去更新的running_mean和running_var作为 μ , σ mu, sigma μ,σ的效果。

     导致上述现象的原因,中心问题就是train模式下利用CVUSA更新ImageNet估计的running_mean和running_var不够稳定,所以解决策略需要从稳定running_mean和running_var上入手。

    • batch size太小,因为bs太小会让统计更新的running_mean和running_var中的噪声比较大。
    • momentum太大,可以调小momentum以估计更多的minibatch,即让running_mean和running_var从更多的minibatch中获得,从而更加稳定。

    1.4 实验验证

    1. 在超算中, 给pretrained的BN重新初始化,再进行试验。

     在这个实验中效果不太好

    1. 重新初始化后再调小momentum,即估计更多的minibatch。

     在这个实验中效果不太好

    1. 真实原因是博主做geo-loc实验,一次前向传播用了地面视角和空域视角两张图片作为一个input pair,而两种视角图片的均值与方差差异很大,导致共享参数的孪生神经网络的BN层无法很好地得到数据集方差与均值的估计值。

    2. 实验经验2

     没有做过这个实验,但是看到论坛中在讨论有关问题,这里就把它放出来。

    2.1 实验现象

     训练模型之后,开train模式进行验证比开eval模式进行验证精度更高。

    2.2 原因分析

     将上述实验现象翻译过来就是,模型开train模式进行评估,即用当前batch的均值与方差去继续更新running_mean,running_var作为 μ , σ mu, sigma μ,σ,这样的效果反而比开eval模式固定统计出来的running_mean,running_var作为 μ , σ mu, sigma μ,σ更好。导致上述问题的原因可能有以下这些:

    • batchsize较小导致的估计不稳定
    • momentum较大导致的估计不稳定
    • 训练集分布与测试集分布不太一致(一般会出现在那些网络收集不出名的小数据集上面,比较经典的数据集应该不太会是这个问题)、

     遇到上述问题的时候一般要先从原因下手分析再解决,但是如果实在难以定位问题(因为虽然底层逻辑的原因一样,但是每个不同的task为什么会导致这样的原因,还是由具体的task决定的),所以博主也提供了以下两种快速解决方案:

    • √ 快速解决方案1:把BN层的track_running_stats属性设置为False再重新训练,这样可以不追踪整个数据集的统计量,而是直接利用当前batch的计算的 μ , σ mu, sigma μ,σ (不过个人不是很建议这么做, bs=1的时候不能估计当前batch的均值和方差)
    • √ 快速解决方案2:直接去掉BN和dropout,不过这样可能会掉一些点

    PS:虽然快速解决方案能快速奏效,但是还是有限制条件的,所以建议还是利用博主提供的分析范式,再结合不同具体的task进行分析,从根本上找到问题,然后再解决。

  • 相关阅读:
    如何做实时监控?—— 参考 Spring Boot 实现
    如何做实时监控?—— 参考 Spring Boot 实现
    spring boot application properties配置详解
    Jrebel 6.2.1破解
    智能社-JS -wiki
    hibernate.properties
    Tomcat 的 socket bind failed的解决方法
    js 排序 SORT 各种方法
    java EE 如何使用Eclipse启动一个项目
    2016-06-06 数组的几个重要方法
  • 原文地址:https://www.cnblogs.com/lsl1229840757/p/14381485.html
Copyright © 2011-2022 走看看