  • Batch Normalization and Batch Renormalization: Detailed Derivation of the Forward and Backward Formulas

    I. BN Forward Pass

    Following the derivation in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", the forward pass consists of four equations:

    \[\mu_B=\frac{1}{m}\sum_{i=1}^m x_i \tag{1}\label{1}\]

    \[\delta_B^2=\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2 \tag{2}\label{2}\]

    \[\widehat{x_i}=\frac{x_i-\mu_B}{\sqrt{\delta_B^2+\epsilon}} \tag{3}\label{3}\]

    \[y_i=\gamma\widehat{x_i}+\beta \tag{4}\label{4}\]

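    Take an MLP as an example and let the mini-batch contain \(m\) samples. Then \(x_i,\ i=1,2,\dots,m\) is one activation value of a given layer for the \(i\)-th sample. In other words, when \(m\) samples are fed in as one training step and the \(i\)-th sample produces \(N\) activation units at layer \(l\), \(x_i\) stands for any one of those units. Writing it as \(x_i^l(n)\) would be more explicit.

    BN therefore takes the \(n\)-th activation unit of layer \(l\), \(x_i^l(n)\), computes its mean and variance over the batch, and standardizes it to obtain \(\widehat{x_i^l(n)}\). The \(m\) normalized activations have mean 0 and variance approximately 1 (exactly 1 only when \(\epsilon=0\)). This mitigates Internal Covariate Shift to some extent, i.e. it reduces the change in the marginal distribution of each layer's activations over the training samples.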
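    As a rough illustration of Eqs. (1) to (4), here is a minimal sketch in plain Python for a single activation unit. The function name `bn_forward`, the default `eps`, and the sample values are assumptions made for this example and do not come from the paper.

```python
import math

def bn_forward(x, gamma, beta, eps=1e-5):
    """Batch-normalize one activation unit over a mini-batch.

    x is a list holding the m values x_1..x_m that this unit takes across
    the batch. Returns y together with the intermediates x_hat, mu, var.
    """
    m = len(x)
    mu = sum(x) / m                                # Eq. (1): batch mean
    var = sum((xi - mu) ** 2 for xi in x) / m      # Eq. (2): biased batch variance
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]      # Eq. (3): standardize
    y = [gamma * xh + beta for xh in x_hat]        # Eq. (4): scale and shift
    return y, x_hat, mu, var

# Illustrative values only.
y, x_hat, mu, var = bn_forward([1.0, 2.0, 3.0, 4.0], gamma=2.0, beta=0.5)
print(mu, var)          # batch statistics
print(sum(x_hat) / 4)   # mean of the normalized values, close to 0
```

    Note that the variance of `x_hat` is `var / (var + eps)`, so it is only approximately 1.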

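    II. BN Backward Pass

    • Let the gradient passed back from the following layer be \(\frac{\partial{L}}{\partial{y_i}}\).
    • We need to compute \(\frac{\partial{L}}{\partial{x_i}}\), \(\frac{\partial{L}}{\partial{\gamma}}\), and \(\frac{\partial{L}}{\partial{\beta}}\).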

    By the chain rule and Eq. \eqref{4}, the contribution of a single sample \(i\) is:

    \[\frac{\partial{L}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{5}\]

    Since \(\frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\) contributes to \(\frac{\partial{L}}{\partial{\gamma}}\) for every \(i=1,2,\dots,m\), the gradient over one training batch is the sum of these contributions:

    \[\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{6}\label{6}\]

    Similarly:

    \[\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}} \tag{7}\label{7}\]
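    Eqs. (6) and (7) can be sanity-checked numerically. The sketch below assumes a hypothetical toy loss \(L=\sum_i w_i y_i\), so that \(\frac{\partial{L}}{\partial{y_i}}=w_i\), and compares the analytic sums against central finite differences. The names `param_grads` and `loss`, and all numbers, are made up for the illustration.

```python
import math

def bn_forward(x, gamma, beta, eps=1e-5):
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    x_hat = [(xi - mu) / math.sqrt(var + eps) for xi in x]
    return [gamma * xh + beta for xh in x_hat], x_hat

def param_grads(dy, x_hat):
    dgamma = sum(g * xh for g, xh in zip(dy, x_hat))  # Eq. (6)
    dbeta = sum(dy)                                   # Eq. (7)
    return dgamma, dbeta

# Hypothetical toy loss L = sum_i w_i * y_i, hence dL/dy_i = w_i.
x = [0.5, -1.2, 3.0, 0.7]
w = [0.3, -0.8, 1.1, 0.4]
gamma, beta = 1.5, -0.2

def loss(g, b):
    y, _ = bn_forward(x, g, b)
    return sum(wi * yi for wi, yi in zip(w, y))

_, x_hat = bn_forward(x, gamma, beta)
dgamma, dbeta = param_grads(w, x_hat)

# Central finite differences as an independent reference.
h = 1e-6
dgamma_num = (loss(gamma + h, beta) - loss(gamma - h, beta)) / (2 * h)
dbeta_num = (loss(gamma, beta + h) - loss(gamma, beta - h)) / (2 * h)
print(abs(dgamma - dgamma_num), abs(dbeta - dbeta_num))
```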

    Computing \(\frac{\partial{L}}{\partial{x_i}}\) is more involved. By the chain rule and Eq. \eqref{3}, treating \(\widehat{x_i}\) as \(g(x_i,\delta_B^2,\mu_B)\):

    \[\frac{\partial{L}}{\partial{x_i}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\left(\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\delta_B^2}}\frac{\partial{\delta_B^2}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\mu_B}}\frac{\partial{\mu_B}}{\partial{x_i}}\right) =\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\left(g_1'+g_2'\frac{\partial{\delta_B^2}}{\partial{x_i}}+g_3'\frac{\partial{\mu_B}}{\partial{x_i}}\right) \tag{8}\label{8}\]

    By Eq. \eqref{2}, the derivative in the second term inside the parentheses can be expanded further, treating \(\delta_B^2\) as \(f(x_i,\mu_B)\):

    \[\frac{\partial{\delta_B^2}}{\partial{x_i}}= \frac{\partial{\delta_B^2}}{\partial{x_i}}+ \frac{\partial{\delta_B^2}}{\partial{\mu_B}} \frac{\partial{\mu_B}}{\partial{x_i}}= f_1'+f_2'\frac{\partial{\mu_B}}{\partial{x_i}} \tag{9}\label{9}\]

    Note that the two occurrences of \(\frac{\partial{\delta_B^2}}{\partial{x_i}}\) in Eq. \eqref{9} mean different things: the one on the left is the total derivative, which accounts for the dependence of \(\mu_B\) on \(x_i\), while the one on the right is the partial derivative \(f_1'\) with \(\mu_B\) held fixed. From Eqs. \eqref{8} and \eqref{9}, once \(f_1',f_2',g_1',g_2',g_3',\frac{\partial{\mu_B}}{\partial{x_i}},\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\) are known, \(\frac{\partial{L}}{\partial{x_i}}\) follows.

    The original paper splits Eq. \eqref{8} into the following terms. Because \(\delta_B^2\) and \(\mu_B\) are shared by every sample in the batch, their gradients are summed over \(i\), which is what the arrows below indicate:

    \[\frac{\partial{L}}{\partial{x_i}}= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\partial{\widehat{x_i}}}{\partial{x_i}}+ \frac{\partial{L}}{\partial{\delta_B^2}} \frac{\partial{\delta_B^2}}{\partial{x_i}}+ \frac{\partial{L}}{\partial{\mu_B}} \frac{\partial{\mu_B}}{\partial{x_i}} \tag{10}\label{10}\]

    where:

    \[\frac{\partial{L}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \frac{\partial{y_i}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \gamma \tag{10.1}\label{10.1}\]

    \[\frac{\partial{\widehat{x_i}}}{\partial{x_i}}=g'_1=\frac{1}{\sqrt{\delta_B^2+\epsilon}} \tag{10.2}\label{10.2}\]

    \[\frac{\partial{L}}{\partial{\delta_B^2}}= \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}} \longrightarrow\]

    \[\sum_{i=1}^m\frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}} \tag{10.3}\label{10.3}\]

    \[\frac{\partial{\delta_B^2}}{\partial{x_i}}=f'_1=\frac{2(x_i-\mu_B)}{m} \tag{10.4}\label{10.4}\]

    \[\frac{\partial{L}}{\partial{\mu_B}}= \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_3+ \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2f'_2 \longrightarrow\]

    \[\sum_{i=1}^m\left( \frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-1}{\sqrt{\delta_B^2+\epsilon}} +\frac{\partial{L}}{\partial{\delta_B^2}}\frac{2(\mu_B-x_i)}{m}\right) \tag{10.5}\label{10.5}\]

    \[\frac{\partial{\mu_B}}{\partial{x_i}}=\frac{1}{m} \tag{10.6}\label{10.6}\]

    The complete BN backward pass is given by Eqs. \eqref{6}, \eqref{7}, and \eqref{10}.
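    To see Eqs. (6), (7), and (10) to (10.6) working together, the following sketch implements the backward pass for one activation unit in plain Python and compares \(\frac{\partial{L}}{\partial{x_i}}\) with a central finite-difference estimate. The toy loss \(L=\sum_i w_i y_i^2\), the function names, and the numbers are assumptions chosen for the check, not part of the original derivation.

```python
import math

def bn_forward(x, gamma, beta, eps=1e-5):
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    x_hat = [(xi - mu) / math.sqrt(var + eps) for xi in x]
    y = [gamma * xh + beta for xh in x_hat]
    return y, (x, x_hat, mu, var, gamma, eps)

def bn_backward(dy, cache):
    x, x_hat, mu, var, gamma, eps = cache
    m = len(x)
    dx_hat = [g * gamma for g in dy]                                   # (10.1)
    dvar = sum(dxh * (mu - xi) / 2 * (var + eps) ** -1.5
               for dxh, xi in zip(dx_hat, x))                          # (10.3)
    dmu = sum(dxh * -1.0 / math.sqrt(var + eps) + dvar * 2 * (mu - xi) / m
              for dxh, xi in zip(dx_hat, x))                           # (10.5)
    dx = [dxh / math.sqrt(var + eps)      # first term of (10), using (10.2)
          + dvar * 2 * (xi - mu) / m      # second term, using (10.4)
          + dmu / m                       # third term, using (10.6)
          for dxh, xi in zip(dx_hat, x)]
    dgamma = sum(g * xh for g, xh in zip(dy, x_hat))                   # (6)
    dbeta = sum(dy)                                                    # (7)
    return dx, dgamma, dbeta

# Hypothetical toy loss L = sum_i w_i * y_i**2, hence dL/dy_i = 2 * w_i * y_i.
x = [0.5, -1.2, 3.0, 0.7]
w = [0.3, -0.8, 1.1, 0.4]
gamma, beta = 1.5, -0.2

def loss(x_in):
    y, _ = bn_forward(x_in, gamma, beta)
    return sum(wi * yi ** 2 for wi, yi in zip(w, y))

y, cache = bn_forward(x, gamma, beta)
dy = [2 * wi * yi for wi, yi in zip(w, y)]
dx, dgamma, dbeta = bn_backward(dy, cache)

# Central finite differences on each x_i as an independent reference.
h = 1e-5
dx_num = []
for i in range(len(x)):
    xp = list(x)
    xm = list(x)
    xp[i] += h
    xm[i] -= h
    dx_num.append((loss(xp) - loss(xm)) / (2 * h))
max_err = max(abs(a - b) for a, b in zip(dx, dx_num))
print(max_err)
```

    One property worth noticing: the entries of `dx` sum to (numerically) zero, because adding the same constant to every \(x_i\) leaves the BN output unchanged.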

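    III. Batch Renormalization

    Reference paper: Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models

    Batch Renormalization is a refinement of standard BN. It is designed to keep the computation used during training consistent with the one used at inference, which addresses the problems BN has with non-i.i.d. mini-batches and with small mini-batches.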

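    1. Forward

    The formulas resemble the original ones, with two additional non-trainable parameters \(r\) and \(d\) (here \(\mu\) and \(\sigma\) denote the moving averages of the batch mean and standard deviation):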

    \[\mu_B=\frac{1}{m}\sum_{i=1}^m x_i \tag{1.1}\label{1.1}\]

    \[\sigma_B=\sqrt{\epsilon+\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2} \tag{2.1}\label{2.1}\]

    \[\widehat{x_i}=\frac{x_i-\mu_B}{\sigma_B}r+d \tag{3.1}\label{3.1}\]

    \[y_i=\gamma\widehat{x_i}+\beta \tag{4.1}\label{4.1}\]

    \[r=Stop\_Gradient(Clip_{[1/r_{max},r_{max}]}(\frac{\sigma_B}{\sigma})) \tag{5.1}\label{5.1}\]

    \[d=Stop\_Gradient(Clip_{[-d_{max},d_{max}]}(\frac{\mu_B-\mu}{\sigma})) \tag{6.1}\label{6.1}\]


    Update moving averages:

    \[\mu:=\mu+\alpha(\mu_B-\mu) \tag{7.1}\label{7.1}\]

    \[\sigma:=\sigma+\alpha(\sigma_B-\sigma) \tag{8.1}\label{8.1}\]

    Inference:

    \[y=\gamma\frac{x-\mu}{\sigma}+\beta \tag{9.1}\label{9.1}\]

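    Standard BN only accumulates the moving mean and variance during training and uses them solely at inference time. BRN, by contrast, uses the moving statistics during both training and inference.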
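    The forward equations (1.1) to (9.1) can be sketched in plain Python as follows. The defaults `r_max=3.0`, `d_max=5.0`, and `alpha=0.01`, as well as the function names, are illustrative assumptions only. Stop-gradient has no effect in a forward-only sketch, so it appears here just as a comment.

```python
import math

def clip(v, lo, hi):
    return max(lo, min(hi, v))

def brn_forward_train(x, gamma, beta, mu_run, sigma_run,
                      r_max=3.0, d_max=5.0, alpha=0.01, eps=1e-5):
    """One training step of Batch Renormalization for a single unit."""
    m = len(x)
    mu_b = sum(x) / m                                                  # (1.1)
    sigma_b = math.sqrt(eps + sum((xi - mu_b) ** 2 for xi in x) / m)   # (2.1)
    # r and d would be treated as constants in the backward pass (stop-gradient).
    r = clip(sigma_b / sigma_run, 1.0 / r_max, r_max)                  # (5.1)
    d = clip((mu_b - mu_run) / sigma_run, -d_max, d_max)               # (6.1)
    x_hat = [(xi - mu_b) / sigma_b * r + d for xi in x]                # (3.1)
    y = [gamma * xh + beta for xh in x_hat]                            # (4.1)
    new_mu = mu_run + alpha * (mu_b - mu_run)                          # (7.1)
    new_sigma = sigma_run + alpha * (sigma_b - sigma_run)              # (8.1)
    return y, new_mu, new_sigma, r, d

def brn_inference(x, gamma, beta, mu_run, sigma_run):
    return gamma * (x - mu_run) / sigma_run + beta                     # (9.1)

# Illustrative values only.
x = [1.0, 2.0, 3.0, 4.0]
y, mu_new, sigma_new, r, d = brn_forward_train(
    x, gamma=2.0, beta=0.5, mu_run=2.0, sigma_run=1.0)
y_inf = [brn_inference(xi, 2.0, 0.5, 2.0, 1.0) for xi in x]
print(max(abs(a - b) for a, b in zip(y, y_inf)))
```

    In this example neither clip is active, so the training output coincides (up to rounding) with the inference formula evaluated with the moving statistics from before the update. That is the consistency between training and inference that the correction terms \(r\) and \(d\) are meant to provide.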

    2. Backward

    The backward derivation is analogous to that of BN:

    \[\frac{\partial{L}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \frac{\partial{y_i}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \gamma \tag{10.11}\label{10.11}\]

    \[\frac{\partial{L}}{\partial{\sigma_B}} \longrightarrow\sum_{i=1}^m \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{-r(x_i-\mu_B)}{\sigma_B^2} \tag{10.22}\label{10.22}\]

    \[\frac{\partial{L}}{\partial{\mu_B}}\longrightarrow\sum_{i=1}^{m}\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-r}{\sigma_B} \tag{10.33}\label{10.33}\]

    \[\frac{\partial{L}}{\partial{x_i}}= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{r}{\sigma_B}+ \frac{\partial{L}}{\partial{\sigma_B}} \frac{x_i-\mu_B}{m\sigma_B}+ \frac{\partial{L}}{\partial{\mu_B}} \frac{1}{m} \tag{10.44}\label{10.44}\]

    \[\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{10.55}\label{10.55}\]

    \[\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}} \tag{10.66}\label{10.66}\]
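    The BRN backward equations (10.11) to (10.66) can be checked in the same way as for BN. In the sketch below \(r\) and \(d\) are passed in as fixed numbers, which mimics the stop-gradient: they are treated as constants when differentiating. The toy loss \(L=\sum_i w_i y_i^2\), the names, and the values are assumptions for illustration.

```python
import math

def brn_forward(x, gamma, beta, r, d, eps=1e-5):
    m = len(x)
    mu_b = sum(x) / m
    sigma_b = math.sqrt(eps + sum((xi - mu_b) ** 2 for xi in x) / m)
    x_hat = [(xi - mu_b) / sigma_b * r + d for xi in x]
    y = [gamma * xh + beta for xh in x_hat]
    return y, (x, x_hat, mu_b, sigma_b, gamma, r)

def brn_backward(dy, cache):
    x, x_hat, mu_b, sigma_b, gamma, r = cache
    m = len(x)
    dx_hat = [g * gamma for g in dy]                                   # (10.11)
    dsigma = sum(dxh * -r * (xi - mu_b) / sigma_b ** 2
                 for dxh, xi in zip(dx_hat, x))                        # (10.22)
    dmu = sum(dxh * -r / sigma_b for dxh in dx_hat)                    # (10.33)
    dx = [dxh * r / sigma_b
          + dsigma * (xi - mu_b) / (m * sigma_b)
          + dmu / m
          for dxh, xi in zip(dx_hat, x)]                               # (10.44)
    dgamma = sum(g * xh for g, xh in zip(dy, x_hat))                   # (10.55)
    dbeta = sum(dy)                                                    # (10.66)
    return dx, dgamma, dbeta

# Hypothetical toy setup; r and d are held constant (stop-gradient).
x = [0.5, -1.2, 3.0, 0.7]
w = [0.3, -0.8, 1.1, 0.4]
gamma, beta, r, d = 1.5, -0.2, 1.3, 0.25

def loss(x_in):
    y, _ = brn_forward(x_in, gamma, beta, r, d)
    return sum(wi * yi ** 2 for wi, yi in zip(w, y))

y, cache = brn_forward(x, gamma, beta, r, d)
dy = [2 * wi * yi for wi, yi in zip(w, y)]
dx, dgamma, dbeta = brn_backward(dy, cache)

# Central finite differences on each x_i as an independent reference.
h = 1e-5
dx_num = []
for i in range(len(x)):
    xp = list(x)
    xm = list(x)
    xp[i] += h
    xm[i] -= h
    dx_num.append((loss(xp) - loss(xm)) / (2 * h))
max_err = max(abs(a - b) for a, b in zip(dx, dx_num))
print(max_err)
```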

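    IV. BN in Convolutional Networks

    The derivations above are all based on an MLP. For convolutional networks, the \(m\) activation values are generalized to all values of a feature map across the batch. Suppose the feature map after some convolutional layer is a tensor of shape \([N,H,W,C]\), where \(N\) is the batch size, \(H\) and \(W\) are the height and width, and \(C\) is the number of feature channels. BN for a convolutional network then sets \(m = N \times H \times W\): every pixel of channel \(c\), over all samples in the batch and all spatial positions, is treated as one \(x_i\), and the BN formulas above apply unchanged. Consequently, each channel of an intermediate activation layer has its own pair of BN parameters \(\gamma,\beta\).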
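    The per-channel reduction can be sketched in plain Python on a nested list with layout `[N][H][W][C]`. The function name `bn_conv_forward` and the test tensor are assumptions for illustration. A real implementation would use an array library and reduce over the `N`, `H`, and `W` axes in one call.

```python
import math

def bn_conv_forward(x, gamma, beta, eps=1e-5):
    """x: nested list with layout [N][H][W][C]; gamma, beta: lists of length C."""
    N, H, W, C = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    m = N * H * W  # number of values normalized together in each channel
    y = [[[[0.0] * C for _ in range(W)] for _ in range(H)] for _ in range(N)]
    for c in range(C):
        vals = [x[n][h][w][c]
                for n in range(N) for h in range(H) for w in range(W)]
        mu = sum(vals) / m
        var = sum((v - mu) ** 2 for v in vals) / m
        inv_std = 1.0 / math.sqrt(var + eps)
        for n in range(N):
            for h in range(H):
                for w in range(W):
                    x_hat = (x[n][h][w][c] - mu) * inv_std
                    y[n][h][w][c] = gamma[c] * x_hat + beta[c]
    return y

# Illustrative 2 x 2 x 2 x 3 tensor; each channel has a different scale and offset.
x = [[[[(c + 1) * (n * 4 + h * 2 + w) + float(c) for c in range(3)]
       for w in range(2)] for h in range(2)] for n in range(2)]
gamma = [1.0, 2.0, 0.5]
beta = [0.0, 1.0, -1.0]
y = bn_conv_forward(x, gamma, beta)

# After BN, the mean of channel c over N, H, W is close to beta[c].
channel_means = [
    sum(y[n][h][w][c] for n in range(2) for h in range(2) for w in range(2)) / 8
    for c in range(3)
]
print(channel_means)
```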

  • Original article: https://www.cnblogs.com/lyc-seu/p/12676505.html