zoukankan      html  css  js  c++  java
  • 带你一文读懂Batch Normalization

    带你一文读懂Batch Normalization

    Batch Normalization的提出就是为了解决深度学习中一个很接近本质的问题:为什么深度神经网络随着深度增加,训练起来越来越困难,收敛也越来越慢?。当然,还有其他的一些方法也是用来解决这个问题的,例如:ResNet、ReLU激活函数等等。不同于ResNet引入shortcut以削弱链式求导过长而带来的网络退化问题,BN着眼于中间层输出的分布,力图从这个方面解决前面的问题。

    1. Internal Covariate Shift

     BN论文的原文叫做 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift。从论文名字上面可以看出来,BN就是要解决训练时Internal Covariate Shift这个问题。
     随着Mini-Batch SGD的引入,这种训练策略逐渐成为深度学习训练网络模型的主流训练策略。而Batch Normalization正是基于这种训练策略的。
     Covariate Shift则是指,如果ML系统中实例集合<X, Y>中的输入X的分布老是发生变化,这就不符合IID假设。通常来说就是训练集和测试集样本分布不一致,特别是现在采用batch的方式训练,随着大量新样本的加入,导致训练集和测试集越来越不一样。
     对于神经网络来说,深度神经网络中间会有许多隐含层,这些隐含层会输出很多的隐含变量,而由于各个隐含层的参数在不断变化,导致每个隐含层的输出分布也会不断变化,所以每个隐含层都会面临着Covariate Shift的问题。这些隐含层由于参数不断改变,导致其输出分布也在不断发生变化的问题就叫做Internal Covariate Shift。

    2. Batch Normalization原理

    这个部分我会结合公式原理和Pytorch代码来从原理和实践中讲解。

     之前的研究表明,如果在图像处理中对输入图像进行白化(Whiten)操作,即将输入数据分布变换到0均值,单位方差的正态分布,那么神经网络会较快收敛。当然,上述的白化操作存在两个重要的问题:

    1. 只是对输入图像进行了白化操作,还是不能解决隐含层的Covariate Shift问题。随着输入不断地被传入更深的隐含层,隐含层的输出仍然是不断变化的。
    2. 原图像并不一定是正态分布的,使用白化操作强制将原图像分布转变为标准正态分布,虽然会加快网络收敛,但是也改变了原图像的分布。

    2.1 Batch Normalization公式详解

     BN的公式如下:
    x ^ ( k ) = x ( k ) − E [ x ( k ) ] V a r [ x ( k ) ] hat{x}^{(k)} = frac{x^{(k)} - E[x^{(k)}]}{sqrt{Var[x^{(k)}]}} x^(k)=Var[x(k)] x(k)E[x(k)]
    y ( k ) = γ ( k ) x ^ ( k ) + β ( k ) y^{(k)} = gamma^{(k)} hat{x}^{(k)} + eta^{(k)} y(k)=γ(k)x^(k)+β(k)
     其中, γ ( k ) gamma^{(k)} γ(k) β ( k ) eta^{(k)} β(k)是两个可学习参数; x ( k ) x^{(k)} x(k)代表的是第k个维度的输入,这里需要注意的是原论文中上标代表的是维度-demension,下标代表的才是batch中第i个输入,例如上面公式的 x ( k ) = ( x 0 ( k ) , x 1 ( k ) , x 2 ( k ) … x m ( k ) ) x^{(k)}= (x_0^{(k)}, x_1^{(k)}, x_2^{(k)}dots x_m^{(k)}) x(k)=(x0(k),x1(k),x2(k)xm(k)) E [ x ( k ) ] E[x^(k)] E[x(k)]代表均值; V a r [ x ( k ) ] Var[x^{(k)}] Var[x(k)]代表方差。
     这里加入 γ gamma γ β eta β这两个可学习的参数就是为了利用仿射变换,使得隐含层的输出不仅仅是正态分布。

    2.2 Batch Normalization2d 举个栗子

     上面的公式看完过后还是觉得比较抽象怎么办?别慌!我用图像处理中最常用的Batch Normalization2d这个层来给大家举个栗子。
     可能有很多同学在自己写代码的时候,会有一个疑问:Pytorch的BN2d层输入是(Batch, Channel, Height, Width)这个亚子的,Batch不算维度,那应该是BN3d呀,为什么这些框架里面都用叫它BN2d呢?注意,在深度学习图像处理计算维度的时候一般会将Height和Width当做一个维度,所以这里其实只有Channel、(Height * Width)两个维度,所以它就叫BN2d。
     有了上面的知识,下面用一个具体的例子来深入理解BN2d的参数(灵魂画手出场):
    在这里插入图片描述
     我们现在有一个Batch中N幅图像,输入进了BN2d。而BN2d在Pytorch代码中有着这4个参数:

    • weights:就是上面公式的 γ gamma γ,通过反向传播学习。
    • bias:就是上面公式的 β eta β,通过反向传播学习。
    • running_mean:就是数据的均值,不过由于是Mini-Batch策略训练的,所以需要采用指数加权平均更新。
    • running_var:数据的方差,同上,也需要指数加权平均更新。

     首先来看running_mean(即 V a r [ x ( k ) ] Var[sqrt{x^{(k)}}] Var[x(k) ])参数,从上面的公式可以看出来我们需要计算每个channel的 V a r [ x ( k ) ] Var[sqrt{x^{(k)}}] Var[x(k) ]。例如计算第一个通道的 V a r [ x ( 0 ) ] Var[sqrt{x^{(0)}}] Var[x(0) ]参数,就需要把这N幅图像的第一个通道合并到一起组成一个(N, H, W)的方框,由于H,W其实是一个参数,所以这个方块又可以写成(N, H*W)。把这个方块所有元素拿来求一个均值,就得到了 V a r [ x ( 0 ) ] Var[sqrt{x^{(0)}}] Var[x(0) ]。其他通道同理,最终我们就可以得到一个大小为C的 V a r [ x ] Var[sqrt{x}] Var[x ]向量。running_var和running_mean算法类似,不过是改成求方差即可。
     再来看看weights和bias这两个learnable的参数,同样的每个通道都有一个 γ ( k ) , β ( k ) gamma^{(k)}, eta^{(k)} γ(k),β(k),所以weights和bias也是一个大小为C的向量。
     如果想了解指数加权平均的原理请移步我的另一篇博客一文带你入门深度学习优化算法

  • 相关阅读:
    【读书笔记】 —— 《数学女孩》
    【读书笔记】 —— 《数学女孩》
    《论语》《大学》《中庸》和孟子
    《论语》《大学》《中庸》和孟子
    零点定理、介值定理
    java学习笔记(3)——面向对象
    linux下的文件操作——批量重命名
    Java学习笔记(4)——JavaSE
    java学习笔记(5)——内部类
    学生管理系统调试——实时错误(实时错误“424”“5”“91”)
  • 原文地址:https://www.cnblogs.com/lsl1229840757/p/14122572.html
Copyright © 2011-2022 走看看