A Detailed Derivation of the Forward and Backward Formulas for Batch Normalization and Batch Renormalization
I. BN Forward Pass
Following the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", the forward pass consists of the following four formulas:
\[
\mu_B=\frac{1}{m}\sum_{i=1}^m x_i \tag{1}\label{1}
\]
\[
\delta_B^2=\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2 \tag{2}\label{2}
\]
\[
\widehat{x_i}=\frac{x_i-\mu_B}{\sqrt{\delta_B^2+\epsilon}} \tag{3}\label{3}
\]
\[
y_i=\gamma\widehat{x_i}+\beta \tag{4}\label{4}
\]
Take an MLP as the example. Suppose a mini-batch of \(m\) samples is fed in; then each \(x_i,\ i=1,2,\dots,m\) here is one particular activation of some layer for the \(i\)-th sample. That is, when \(m\) samples are processed in one training step and the \(i\)-th sample produces \(N\) activation units at layer \(l\), \(x_i\) refers to any one of those units. Strictly speaking, writing it as \(x_i^l(n)\) would be more explicit.
So BN simply takes the \(n\)-th activation unit of layer \(l\), \(x_i^l(n)\), computes its mean and variance over a batch, and standardizes it to obtain \(\widehat{x_i^l(n)}\). The \(m\) normalized activations then have zero mean and unit variance, which mitigates Internal Covariate Shift to some extent by reducing how much the marginal distribution of each layer's activations shifts across the training samples.
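As a concrete illustration, here is a minimal NumPy sketch of the training-time forward pass described by equations \eqref{1}-\eqref{4}; the function and variable names (`bn_forward`, `x`, `gamma`, `beta`, `eps`) are my own choices rather than anything from the paper.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """Training-time BN forward for an activation matrix x of shape (m, N).

    Each of the N activation units is normalized over the m samples in the
    batch, following equations (1)-(4).
    """
    mu = x.mean(axis=0)                      # (1) per-unit batch mean
    var = x.var(axis=0)                      # (2) per-unit batch variance (1/m)
    x_hat = (x - mu) / np.sqrt(var + eps)    # (3) standardize
    y = gamma * x_hat + beta                 # (4) scale and shift
    cache = (x, x_hat, mu, var, gamma, eps)  # saved for the backward pass
    return y, cache
```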
II. BN Backward Pass
- Let the incoming gradient be \(\frac{\partial{L}}{\partial{y_i}}\).
- We need to compute \(\frac{\partial{L}}{\partial{x_i}}\), \(\frac{\partial{L}}{\partial{\gamma}}\), and \(\frac{\partial{L}}{\partial{\beta}}\).

By the chain rule and equation \eqref{4}:
\[
\frac{\partial{L}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{5}
\]
Since \(\frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\) contributes to \(\frac{\partial{L}}{\partial{\gamma}}\) for every \(i=1,2,\dots,m\), over one training batch \(\frac{\partial{L}}{\partial{\gamma}}\) is defined as:
\[
\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{6}\label{6}
\]
Similarly:
\[
\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}} \tag{7}\label{7}
\]
Computing \(\frac{\partial{L}}{\partial{x_i}}\) is more involved. By the chain rule and equation \eqref{3}, viewing \(\widehat{x_i}\) as \(g(x_i,\delta_B^2,\mu_B)\), we have:
\[
\frac{\partial{L}}{\partial{x_i}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\left(\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\delta_B^2}}\frac{\partial{\delta_B^2}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\mu_B}}\frac{\partial{\mu_B}}{\partial{x_i}}\right)
=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\left(g_1'+g_2'\frac{\partial{\delta_B^2}}{\partial{x_i}}+g_3'\frac{\partial{\mu_B}}{\partial{x_i}}\right)
\tag{8}\label{8}
\]
By equation \eqref{2}, the partial derivative in the second term inside the parentheses can be expanded further (viewing \(\delta_B^2\) as \(f(x_i,\mu_B)\)):
\[
\frac{\partial{\delta_B^2}}{\partial{x_i}}=
\frac{\partial{\delta_B^2}}{\partial{x_i}}+
\frac{\partial{\delta_B^2}}{\partial{\mu_B}}
\frac{\partial{\mu_B}}{\partial{x_i}}=
f_1'+f_2'\frac{\partial{\mu_B}}{\partial{x_i}}
\tag{9}\label{9}
\]
Note that the two occurrences of \(\frac{\partial{\delta_B^2}}{\partial{x_i}}\) in equation \eqref{9} mean different things: the left-hand side is the total derivative of \(\delta_B^2\) with respect to \(x_i\), while the first term on the right is the partial derivative with \(\mu_B\) held fixed. By equations \eqref{8} and \eqref{9}, once \(f_1', f_2', g_1', g_2', g_3', \frac{\partial{\mu_B}}{\partial{x_i}}, \frac{\partial{y_i}}{\partial{\widehat{x_i}}}\) are known, \(\frac{\partial{L}}{\partial{x_i}}\) follows.
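Differentiating \eqref{2} and \eqref{3} term by term gives these pieces explicitly (collected here for reference; they reappear inside equations (10.1)-(10.6) below):

\[
g_1'=\frac{\partial{\widehat{x_i}}}{\partial{x_i}}=\frac{1}{\sqrt{\delta_B^2+\epsilon}},\qquad
g_2'=\frac{\partial{\widehat{x_i}}}{\partial{\delta_B^2}}=\frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}},\qquad
g_3'=\frac{\partial{\widehat{x_i}}}{\partial{\mu_B}}=\frac{-1}{\sqrt{\delta_B^2+\epsilon}}
\]
\[
f_1'=\frac{\partial{\delta_B^2}}{\partial{x_i}}=\frac{2(x_i-\mu_B)}{m},\qquad
f_2'=\frac{\partial{\delta_B^2}}{\partial{\mu_B}}=\frac{2}{m}\sum_{i=1}^m(\mu_B-x_i),\qquad
\frac{\partial{\mu_B}}{\partial{x_i}}=\frac{1}{m},\qquad
\frac{\partial{y_i}}{\partial{\widehat{x_i}}}=\gamma
\]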
The original paper splits equation \eqref{8} into the following terms:
\[
\frac{\partial{L}}{\partial{x_i}}=
\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+
\frac{\partial{L}}{\partial{\delta_B^2}}
\frac{\partial{\delta_B^2}}{\partial{x_i}}+
\frac{\partial{L}}{\partial{\mu_B}}
\frac{\partial{\mu_B}}{\partial{x_i}}
\tag{10}\label{10}
\]
where:
\[
\frac{\partial{L}}{\partial{\widehat{x_i}}}=
\frac{\partial{L}}{\partial{y_i}}
\frac{\partial{y_i}}{\partial{\widehat{x_i}}}=
\frac{\partial{L}}{\partial{y_i}}\gamma \tag{10.1}\label{10.1}
\]
\[
\frac{\partial{\widehat{x_i}}}{\partial{x_i}}=g'_1=\frac{1}{\sqrt{\delta_B^2+\epsilon}}
\tag{10.2}\label{10.2}
\]
\[
\frac{\partial{L}}{\partial{\delta_B^2}}=
\frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2=
\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}}
\longrightarrow
\]
\[
\sum_{i=1}^m\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}}
\tag{10.3}\label{10.3}
\]
\[
\frac{\partial{\delta_B^2}}{\partial{x_i}}=f'_1=\frac{2(x_i-\mu_B)}{m}
\tag{10.4}\label{10.4}
\]
\[
\frac{\partial{L}}{\partial{\mu_B}}=
\frac{\partial{L}}{\partial{\widehat{x_i}}}g'_3+
\frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2f'_2
\longrightarrow
\]
\[
\sum_{i=1}^m\left(
\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-1}{\sqrt{\delta_B^2+\epsilon}}
+\frac{\partial{L}}{\partial{\delta_B^2}}\frac{2(\mu_B-x_i)}{m}\right)
\tag{10.5}\label{10.5}
\]
\[
\frac{\partial{\mu_B}}{\partial{x_i}}=\frac{1}{m}
\tag{10.6}\label{10.6}
\]
The complete BN backward pass is thus given by equations \eqref{6}, \eqref{7}, and \eqref{10}.
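As a sanity check, here is a minimal NumPy sketch of that backward pass, built directly from equations \eqref{6}, \eqref{7}, and (10.1)-(10.6); it consumes the cache returned by the forward sketch above, and all names are my own.

```python
import numpy as np

def bn_backward(dy, cache):
    """Backward pass for BN, following equations (6), (7) and (10.1)-(10.6).

    dy is dL/dy with shape (m, N); returns (dL/dx, dL/dgamma, dL/dbeta).
    """
    x, x_hat, mu, var, gamma, eps = cache
    m = x.shape[0]

    dgamma = np.sum(dy * x_hat, axis=0)                 # (6)
    dbeta = np.sum(dy, axis=0)                          # (7)

    dx_hat = dy * gamma                                 # (10.1)
    inv_std = 1.0 / np.sqrt(var + eps)                  # (10.2)
    dvar = np.sum(dx_hat * (mu - x) * 0.5 * inv_std**3, axis=0)   # (10.3)
    dmu = (np.sum(-dx_hat * inv_std, axis=0)                      # (10.5)
           + dvar * np.sum(2.0 * (mu - x), axis=0) / m)
    dx = dx_hat * inv_std + dvar * 2.0 * (x - mu) / m + dmu / m   # (10), (10.4), (10.6)
    return dx, dgamma, dbeta
```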
III. Batch Renormalization
Reference: the paper "Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models".
Batch Renormalization is a refinement of standard BN: it keeps the training and inference computations consistent, and it addresses the problems caused by non-i.i.d. data and small minibatches.
1. Forward pass
The formulas are similar to the original ones, with two additional non-trainable parameters \(r, d\):
\[
\mu_B=\frac{1}{m}\sum_{i=1}^m x_i \tag{1.1}\label{1.1}
\]
\[
\sigma_B=\sqrt{\epsilon+\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2} \tag{2.1}\label{2.1}
\]
\[
\widehat{x_i}=\frac{x_i-\mu_B}{\sigma_B}r+d \tag{3.1}\label{3.1}
\]
\[
y_i=\gamma\widehat{x_i}+\beta \tag{4.1}\label{4.1}
\]
\[
r=Stop\_Gradient\left(Clip_{[1/r_{max},\,r_{max}]}\left(\frac{\sigma_B}{\sigma}\right)\right) \tag{5.1}\label{5.1}
\]
\[
d=Stop\_Gradient\left(Clip_{[-d_{max},\,d_{max}]}\left(\frac{\mu_B-\mu}{\sigma}\right)\right) \tag{6.1}\label{6.1}
\]
Update moving averages:
\[
\mu:=\mu+\alpha(\mu_B-\mu) \tag{7.1}\label{7.1}
\]
\[
\sigma:=\sigma+\alpha(\sigma_B-\sigma) \tag{8.1}\label{8.1}
\]
Inference:
\[
y=\gamma\frac{x-\mu}{\sigma}+\beta \tag{9.1}\label{9.1}
\]
Whereas standard BN only accumulates the moving mean and variance during training and uses them only at inference time, BRN uses the moving statistics during both training and inference.
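Below is a minimal NumPy sketch of the BRN training-time forward step following equations (1.1)-(8.1); the moving statistics are passed in and returned explicitly, and the names (`mu_run`, `sigma_run`, `r_max`, `d_max`, `alpha`) are my own choices, not from the paper.

```python
import numpy as np

def brn_forward_train(x, gamma, beta, mu_run, sigma_run,
                      r_max=3.0, d_max=5.0, alpha=0.01, eps=1e-5):
    """Training-time Batch Renormalization forward, equations (1.1)-(8.1).

    x has shape (m, N); mu_run / sigma_run are the moving mean and std.
    Returns y, the correction factors (r, d), and the updated moving stats.
    """
    mu_b = x.mean(axis=0)                               # (1.1)
    sigma_b = np.sqrt(x.var(axis=0) + eps)              # (2.1)

    # (5.1), (6.1): r and d are treated as constants (no gradient flows through them)
    r = np.clip(sigma_b / sigma_run, 1.0 / r_max, r_max)
    d = np.clip((mu_b - mu_run) / sigma_run, -d_max, d_max)

    x_hat = (x - mu_b) / sigma_b * r + d                # (3.1)
    y = gamma * x_hat + beta                            # (4.1)

    # (7.1), (8.1): update the moving statistics
    mu_run = mu_run + alpha * (mu_b - mu_run)
    sigma_run = sigma_run + alpha * (sigma_b - sigma_run)
    return y, (r, d), (mu_run, sigma_run)
```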
2. Backward pass
The backward derivation is analogous to that of BN:
\[
\frac{\partial{L}}{\partial{\widehat{x_i}}}=
\frac{\partial{L}}{\partial{y_i}}
\frac{\partial{y_i}}{\partial{\widehat{x_i}}}=
\frac{\partial{L}}{\partial{y_i}}\gamma \tag{10.11}\label{10.11}
\]
\[
\frac{\partial{L}}{\partial{\sigma_B}}
\longrightarrow\sum_{i=1}^m
\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{-r(x_i-\mu_B)}{\sigma_B^2}
\tag{10.22}\label{10.22}
\]
\[
\frac{\partial{L}}{\partial{\mu_B}}\longrightarrow\sum_{i=1}^{m}\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-r}{\sigma_B}
\tag{10.33}\label{10.33}
\]
\[
\frac{\partial{L}}{\partial{x_i}}=
\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{r}{\sigma_B}+
\frac{\partial{L}}{\partial{\sigma_B}}
\frac{x_i-\mu_B}{m\sigma_B}+
\frac{\partial{L}}{\partial{\mu_B}}
\frac{1}{m}
\tag{10.44}\label{10.44}
\]
\[
\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{10.55}\label{10.55}
\]
\[
\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}} \tag{10.66}\label{10.66}
\]
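And a matching NumPy sketch of the BRN backward pass from equations (10.11)-(10.66), again with names of my own choosing:

```python
import numpy as np

def brn_backward(dy, x, x_hat, mu_b, sigma_b, r, gamma):
    """Backward pass for Batch Renormalization, equations (10.11)-(10.66).

    dy is dL/dy with shape (m, N); r is the (stop-gradient) correction factor.
    """
    m = x.shape[0]

    dx_hat = dy * gamma                                                 # (10.11)
    dsigma_b = np.sum(dx_hat * (-r) * (x - mu_b) / sigma_b**2, axis=0)  # (10.22)
    dmu_b = np.sum(dx_hat * (-r) / sigma_b, axis=0)                     # (10.33)
    dx = (dx_hat * r / sigma_b                                          # (10.44)
          + dsigma_b * (x - mu_b) / (m * sigma_b)
          + dmu_b / m)
    dgamma = np.sum(dy * x_hat, axis=0)                                 # (10.55)
    dbeta = np.sum(dy, axis=0)                                          # (10.66)
    return dx, dgamma, dbeta
```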
IV. BN in Convolutional Networks
The derivations above are all based on an MLP. For a convolutional network, the \(m\) activation units in the BN procedure generalize to \(m\) feature-map values. Suppose the feature map after some convolution layer is an \([N,H,W,C]\) tensor, where \(N\) is the batch size, \(H, W\) are the height and width, and \(C\) is the number of channels. To apply BN in a convolutional network, set \(m = N \times H \times W\): every feature-map pixel of a given channel \(c\), over all samples in the batch, is treated as an \(x_i\), and the BN formulas above apply unchanged. Consequently, in a convolutional network each channel of an intermediate activation layer has its own pair of BN parameters \(\gamma, \beta\).
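For instance, a minimal NumPy sketch of the training-time forward statistics for an NHWC feature map, normalizing each channel over \(m = N \times H \times W\) values (names are my own):

```python
import numpy as np

def bn_forward_conv(x, gamma, beta, eps=1e-5):
    """BN forward for a conv feature map x of shape (N, H, W, C).

    Statistics are computed over the N*H*W positions of each channel,
    so gamma and beta each have shape (C,).
    """
    mu = x.mean(axis=(0, 1, 2))                 # per-channel mean over m = N*H*W values
    var = x.var(axis=(0, 1, 2))                 # per-channel variance
    x_hat = (x - mu) / np.sqrt(var + eps)       # broadcast along the channel axis
    return gamma * x_hat + beta
```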