zoukankan      html  css  js  c++  java
  • 干货 | 这可能全网最好的BatchNorm详解

    文章来自:公众号【机器学习炼丹术】。求关注~

    其实关于BN层,我在之前的文章“梯度爆炸”那一篇中已经涉及到了,但是鉴于面试经历中多次问道这个,这里再做一个更加全面的讲解。

    Internal Covariate Shift(ICS)

    Batch Normalization的原论文作者给了Internal Covariate Shift一个较规范的定义:在深层网络训练的过程中,由于网络中参数变化而引起内部结点数据分布发生变化的这一过程被称作Internal Covariate Shift。

    这里做一个简单的数学定义,对于全链接网络而言,第i层的数学表达可以体现为:
    (Z^i=W^i imes input^i+b^i)
    (input^{i+1}=g^i(Z^i))

    • 第一个公式就是一个简单的线性变换;
    • 第二个公式是表示一个激活函数的过程。

    【怎么理解ICS问题】
    我们知道,随着梯度下降的进行,每一层的参数(W^i,b^i)都会不断地更新,这意味着(Z^i)的分布也不断地改变,从而(input^{i+1})的分布发生了改变。这意味着,除了第一层的输入数据不改变,之后所有层的输入数据的分布都会随着模型参数的更新发生改变,而每一层就要不停的去适应这种数据分布的变化,这个过程就是Internal Covariate Shift。

    BN解决的问题

    【ICS带来的收敛速度慢】
    因为每一层的参数不断发生变化,从而每一层的计算结果的分布发生变化,后层网络不断地适应这种分布变化,这个时候会让整个网络的学习速度过慢。

    【梯度饱和问题】
    因为神经网络中经常会采用sigmoid,tanh这样的饱和激活函数(saturated actication function),因此模型训练有陷入梯度饱和区的风险。解决这样的梯度饱和问题有两个思路:第一种就是更为非饱和性激活函数,例如线性整流函数ReLU可以在一定程度上解决训练进入梯度饱和区的问题。另一种思路是,我们可以让激活函数的输入分布保持在一个稳定状态来尽可能避免它们陷入梯度饱和区,这也就是Normalization的思路。

    Batch Normalization

    batchNormalization就像是名字一样,对一个batch的数据进行normalization。

    现在假设一个batch有3个数据,每个数据有两个特征:(1,2),(2,3),(0,1)

    如果做一个简单的normalization,那么就是计算均值和方差,把数据减去均值除以标准差,变成0均值1方差的标准形式。

    对于第一个特征来说:
    (mu=frac{1}{3}(1+2+0)=1)
    (sigma^2=frac{1}{3}((1-1)^2+(2-1)^2+(0-1)^2)=0.67)

    【通用公式】
    (mu=frac{1}{m}sum_{i=1}^m{Z})
    (sigma^2=frac{1}{m}sum_{i=1}^m(Z-mu))
    (hat{Z}=frac{Z-mu}{sqrt{sigma^2+epsilon}})

    • 其中m表示一个batch的数量。
    • (epsilon)是一个极小数,防止分母为0。

    目前为止,我们做到了让每个特征的分布均值为0,方差为1。这样分布都一样,一定不会有ICS问题

    如同上面提到的,Normalization操作我们虽然缓解了ICS问题,让每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。每一层的分布都相同,所有任务的数据分布都相同,模型学啥呢

    【0均值1方差数据的弊端】

    1. 数据表达能力的缺失;
    2. 通过让每一层的输入分布均值为0,方差为1,会使得输入在经过sigmoid或tanh激活函数时,容易陷入非线性激活函数的线性区域。(线性区域和饱和区域都不理想,最好是非线性区域)

    为了解决这个问题,BN层引入了两个可学习的参数(gamma)(eta),这样,经过BN层normalization的数据其实是服从(eta)均值,(gamma^2)方差的数据。

    所以对于某一层的网络来说,我们现在变成这样的流程:

    1. (Z=W imes input^i+b)
    2. (hat{Z}=gamma imes frac{Z-mu}{sqrt{sigma^2+epsilon}}+eta)
    3. (input^{i+1}=g(hat{Z}))

    (上面公式中,省略了(i),总的来说是表示第i层的网络层产生第i+1层输入数据的过程)

    测试阶段的BN

    我们知道BN在每一层计算的(mu)(sigma^2) 都是基于当前batch中的训练数据,但是这就带来了一个问题:我们在预测阶段,有可能只需要预测一个样本或很少的样本,没有像训练样本中那么多的数据,这样的(sigma^2)(mu)要怎么计算呢?

    利用训练集训练好模型之后,其实每一层的BN层都保留下了每一个batch算出来的(mu)(sigma^2).然后呢利用整体的训练集来估计测试集的(mu_{test})(sigma_{test}^2)
    (mu_{test}=E(mu_{train}))
    (sigma_{test}^2=frac{m}{m-1}E(sigma_{train}^2))
    然后再对测试机进行BN层:

    当然,计算训练集的(mu)(simga)的方法除了上面的求均值之外。吴恩达老师在其课程中也提出了,可以使用指数加权平均的方法。不过都是同样的道理,根据整个训练集来估计测试机的均值方差。

    BN层的好处有哪些

    1. BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度。
      BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。

    2. BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题
      通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题;另外通过自适应学习(gamma)(eta) 又让数据保留更多的原始信息。

    3. BN具有一定的正则化效果
      在Batch Normalization中,由于我们使用mini-batch的均值与方差作为对整体训练样本均值与方差的估计,尽管每一个batch中的数据都是从总体样本中抽样得到,但不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音

    BN与其他normalizaiton的比较

    【weight normalization】
    Weight Normalization是对网络权值进行normalization,也就是L2 norm。

    相对于BN有下面的优势:

    1. WN通过重写神经网络的权重的方式来加速网络参数的收敛,不依赖于mini-batch。BN因为以来minibatch所以BN不能用于RNN网路,而WN可以。而且BN要保存每一个batch的均值方差,所以WN节省内存;
    2. BN的优点中有正则化效果,但是添加噪音不适合对噪声敏感的强化学习、GAN等网络。WN可以引入更小的噪音。

    但是WN要特别注意参数初始化的选择。


    【Layer normalization】
    更常见的比较是BN与LN的比较。
    BN层有两个缺点:

    1. 无法进行在线学习,因为在线学习的mini-batch为1;LN可以
    2. 之前提到的BN不能用在RNN中;LN可以
    3. 消耗一定的内存来记录均值和方差;LN不用

    但是,在CNN中LN并没有取得比BN更好的效果。

    参考链接:

    1. https://zhuanlan.zhihu.com/p/34879333
    2. https://www.zhihu.com/question/59728870
    3. https://zhuanlan.zhihu.com/p/113233908
    4. https://www.zhihu.com/question/55890057/answer/267872896



  • 相关阅读:
    一个程序员的职业规划
    基于Andoird 4.2.2的Account Manager源代码分析学习:创建选定类型的系统帐号
    [置顶] C++学习书单
    js快速分享代码
    The declared package does not match the expected package
    IBM Rational Appscan Part 1
    IBM Rational Appscan: Part 2 ---reference
    阅读redis源代码的一些体会
    18 Command Line Tools to Monitor Linux Performance
    代码规范
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/13358410.html
Copyright © 2011-2022 走看看