zoukankan      html  css  js  c++  java
  • 《神经网络的梯度推导与代码验证》之vanilla RNN的前向传播和反向梯度推导

    在本篇章,我们将专门针对vanilla RNN,也就是所谓的原始RNN这种网络结构进行前向传播介绍和反向梯度推导。更多相关内容请见《神经网络的梯度推导与代码验证》系列介绍

     

    注意:


     目录

    提醒:

    • 后续会反复出现$oldsymbol{delta}^{l}$这个(类)符号,它的定义为$oldsymbol{delta}^{l} = frac{partial l}{partialoldsymbol{z}^{oldsymbol{l}}}$,即loss $l$对$oldsymbol{z}^{oldsymbol{l}}$的导数
    • 其中$oldsymbol{z}^{oldsymbol{l}}$表示第$l$层(DNN,CNN,RNN或其他例如max pooling层等)未经过激活函数的输出。
    • $oldsymbol{a}^{oldsymbol{l}}$则表示$oldsymbol{z}^{oldsymbol{l}}$经过激活函数后的输出。

    这些符号会贯穿整个系列,还请留意。


     

    4.1 vanilla RNN的前向传播

    先贴一张vanilla(朴素)RNN的前传示意图。

    上图中左边是RNN模型没有按时间展开的图,如果按时间序列展开,则是上图中的右边部分。我们重点观察右边部分的图。这幅图描述了在序列索引号t附近RNN的模型。其中:

    • $oldsymbol{x}^{(t)}$代表在序列索引号$t$时训练样本的输入。注意这里的$t$只是代表序列索引,不一定非得具备时间上的含义,例如$oldsymbol{x}^{(t)}$可以是某句子的第$t$个字(的词向量)。
    • $oldsymbol{h}^{(t)}$代表在序列索引号$t$时模型的隐藏状态。$oldsymbol{h}^{(t)}$由$oldsymbol{x}^{(t)}$和$oldsymbol{h}^{(t-1)}$共同决定
    • $oldsymbol{a}^{(t)}$代表在序列索引号$t$时模型的输出。$oldsymbol{o}^{(t)}$只由模型当前的隐藏状态$oldsymbol{h}^{(t-1)}$决定
    • $oldsymbol{L}^{(t)}$代表在序列索引号$t$时模型的损失函数。
    • $oldsymbol{y}^{(t)}$代表在序列索引号$t$时训练样本序列的真实输出
    • $oldsymbol{U},oldsymbol{W},oldsymbol{V}$三个矩阵式我们模型的线性相关系数,它们在整个vanilla RNN网络中共享的,这点和DNN很不同。也正因为是共享的,它体现了RNN模型的“循环/递归”的核心思想。

    4.1.1 RNN前向传播计算公式

    有了上面的模型,RNN的前向传播算法就很容易得到了。

     

    对于任意一个序列索引号$t$,我们隐藏状态$oldsymbol{h}^{(t)}$由$oldsymbol{x}^{(t)}$和$oldsymbol{h}^{(t-1)}$共同得到:

    $oldsymbol{h}^{(t)} = sigmaleft( oldsymbol{z}^{(t)} ight) = sigmaleft( {oldsymbol{U}oldsymbol{x}^{(t)} + oldsymbol{W}oldsymbol{h}^{(t - 1)} + oldsymbol{b}} ight)$

    其中$sigma$为RNN的激活函数,一般为$tanh$。

    序列索引号为$t$时,模型的输出$oldsymbol{o}^{(t)}$的表达式也比较简单:

    $oldsymbol{o}^{(t)} = oldsymbol{V}oldsymbol{h}^{(t - 1)} + oldsymbol{c}$

     

    在最终在序列索引号时我们的预测输出为:

    ${hat{oldsymbol{y}}}^{(t)} = sigmaleft( oldsymbol{o}^{(t)} ight)$

     

    对比下列公式:

    $oldsymbol{h}^{(t)} = sigmaleft( {oldsymbol{U}oldsymbol{x}^{(t)} + oldsymbol{W}oldsymbol{h}^{(t - 1)} + oldsymbol{b}} ight)$

    $oldsymbol{a}^{l} = sigmaleft( {oldsymbol{W}^{l}oldsymbol{a}^{l - 1} + oldsymbol{b}^{l}} ight)$

     

    上面的是vanilla RNN的$oldsymbol{h}^{(t)}$的递推公式,而下面的是DNN中的层间关系的公式。我们可以发现这两组公式在形式上非常接近。如果将$oldsymbol{h}^{(t)}$的这种时间上的展开看成类似于DNN这种层间堆叠的话,可以发现vanilla RNN每一“层”除了有来自上一“层”的输入$oldsymbol{h}^{(t - 1)}$,还有专属于这一层的输入$oldsymbol{x}^{(t)}$,最重要的是,每一“层”的参数$oldsymbol{W}$和$oldsymbol{b}$都是同一组。而DNN则是有专属于那一层的$oldsymbol{W}^{l}$和$oldsymbol{b}^{l}$。


    4.2 vanilla RNN的反向梯度推导

    RNN反向传播算法的思路和DNN是一样的,即通过梯度下降法一轮轮的迭代,得到合适的RNN模型参数$oldsymbol{U},oldsymbol{W},oldsymbol{V},oldsymbol{b},oldsymbol{c}$。由于我们是基于时间反向传播,所以RNN的反向传播有时也叫做BPTT(back-propagation through time)。当然这里的BPTTDNN也有很大的不同点,即这里所有的$oldsymbol{U},oldsymbol{W},oldsymbol{V},oldsymbol{b},oldsymbol{c}$在序列的各个位置是共享的,反向传播时我们更新的是相同的参数。

     

    为了简化描述,这里的损失函数我们为交叉熵损失函数,输出的激活函数为softmax函数,隐藏层的激活函数为tanh函数。

     

    如果RNN在序列的每个位置有输出,则最终的损失L为所有时间步$t$的loss之和:

    $L = {sumlimits_{t = 1}^{T}L^{(t)}}$

    其中,$oldsymbol{V},oldsymbol{c}$的梯度计算比较简单,跟求DNNBP是一样的。

    根据 数学基础篇:矩阵微分与求导 1.8节例子的中间结果,我们可以知道:

    $frac{partial L}{partialoldsymbol{c}} = {sumlimits_{t = 1}^{T}frac{partial L^{(t)}}{partialoldsymbol{c}}} = {sumlimits_{t = 1}^{T}{{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}}}$

    $frac{partial L}{partialoldsymbol{V}} = {sumlimits_{t = 1}^{T}frac{partial L^{(t)}}{partialoldsymbol{V}}} = {sumlimits_{t = 1}^{T}left( {{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}} ight)}left( oldsymbol{h}^{(t)} ight)^{T}$

     

    接下来的$oldsymbol{U},oldsymbol{W},oldsymbol{b}$的梯度计算就相对复杂了。从RNN的模型可以看出,在反向传播时,某一序列位置$t$的梯度由当前位置的输出对应的梯度和序列索引位置$t+1$时的梯度两部分共同决定。对于$oldsymbol{W}$在某一序列位置$t$的梯度损失需要反向传播一步一步地计算。我们定义序列索引$t$位置的隐藏状态的梯度为:

    $oldsymbol{delta}^{(t)} = frac{partial L}{partialoldsymbol{h}^{(t)}}$

     

    如果我们能知道$oldsymbol{delta}^{(t)}$,那么根据$oldsymbol{h}^{(t)} = sigmaleft( oldsymbol{z}^{(t)} ight) = sigmaleft( {oldsymbol{U}oldsymbol{x}^{(t)} + oldsymbol{W}oldsymbol{h}^{(t - 1)} + oldsymbol{b}} ight)$我们就像DNN那样套用标量对矩阵的链式求导法则来进一步得到$oldsymbol{U},oldsymbol{W},oldsymbol{b}$的梯度了。

     

    根据4.1节中的示意图我们可以轻易发现,当$t = T$,则误差只有$left. L^{(T)} ightarrowoldsymbol{h}^{(T)} ight.$这么一条。

    所以:

    $oldsymbol{delta}^{(T)} = oldsymbol{V}^{T}left( {{hat{oldsymbol{y}}}^{(T)} - oldsymbol{y}^{(T)}} ight)$

     

    而当$t<T$时,$oldsymbol{h}^{(t)}$的误差来源有两条:

    1)$left. L^{(t)} ightarrowoldsymbol{h}^{(t)} ight.$

    2)$left. oldsymbol{h}^{({t + 1})} ightarrowoldsymbol{h}^{(t)} ight.$

     

    于是我们得到:

    $oldsymbol{delta}^{(t)} = frac{partial L^{(t)}}{partialoldsymbol{h}^{(t)}} + left( frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(t + 1)}}$

     

    我们来逐项求解:

    首先对于$frac{partial L^{(t)}}{partialoldsymbol{h}^{(t)}}$:

    $oldsymbol{delta}^{(t)} = frac{partial L}{partialoldsymbol{h}^{(t)}} = left( frac{partialoldsymbol{o}^{(t)}}{partialoldsymbol{h}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{o}^{(t)}} = oldsymbol{V}^{T}left( {{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}} ight)$

     

    对于$left( frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} ight)^{T}frac{partial L^{({t + 1})}}{partialoldsymbol{h}^{(t + 1)}}$,我们先关注$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$:

    因为$oldsymbol{h}^{(t + 1)} = sigmaleft( oldsymbol{z}^{(t)} ight) = sigmaleft( {oldsymbol{U}oldsymbol{x}^{(t + 1)} + oldsymbol{W}oldsymbol{h}^{(t)} + oldsymbol{b}} ight)$

    所以有:

     $doldsymbol{h}^{(t + 1)} = sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)igodot doldsymbol{z}^{(t)} = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)doldsymbol{z}^{(t)} = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)dleft( {oldsymbol{W}oldsymbol{h}^{(t)}} ight) = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)oldsymbol{W}doldsymbol{h}^{(t)}$

    所以有:$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)oldsymbol{W}$

    于是:

    $oldsymbol{delta}^{(t)} = oldsymbol{V}^{T}left( {{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}} ight) + oldsymbol{W}^{T}diagleft( {sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)} ight)oldsymbol{delta}^{(t + 1)}$

     

    有了$oldsymbol{delta}^{(T)}$以及从$oldsymbol{delta}^{(t + 1)}$到$oldsymbol{delta}^{(t)}$的递推公式,我们可以轻易求出$oldsymbol{U},oldsymbol{W},oldsymbol{b}$的梯度,由于这三组变量在不同的$t$下是公用的,所以由全微分方程可知,这三个变量应当都是在$t$上的某种累加形式。我们定义只在时间步$t$使用的虚拟变量$oldsymbol{U}^{(t)},oldsymbol{W}^{(t)},oldsymbol{b}^{(t)}$,这样就可以用$frac{partial L}{partialoldsymbol{W}^{(t)}}$来表示$oldsymbol{W}$在时间步$t$的时候对梯度的贡献:

    $frac{partial L}{partialoldsymbol{W}} = {sumlimits_{t = 1}^{T}frac{partial L}{partialoldsymbol{W}^{(t)}}} = {sumlimits_{t = 1}^{T}{left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{W}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(t)}} =}}{sumlimits_{t = 1}^{T}{diagleft( {sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)} ight)oldsymbol{delta}^{(t)}left( oldsymbol{h}^{(t - 1)} ight)^{T}}}$

     

    同理,我们得到:

    $frac{partial L}{partialoldsymbol{b}} = {sumlimits_{t = 1}^{T}{frac{partial L}{partialoldsymbol{b}^{(t)}} =}}{sumlimits_{t = 1}^{T}{left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{b}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(t)}} = {sumlimits_{t = 1}^{T}{diagleft( {sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)} ight)oldsymbol{delta}^{(t)}}}}}$

    $frac{partial L}{partialoldsymbol{U}} = {sumlimits_{t = 1}^{T}{frac{partial L}{partialoldsymbol{U}^{(t)}} =}}{sumlimits_{t = 1}^{T}{left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{U}^{(t)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(t)}} = {sumlimits_{t = 1}^{T}{diagleft( {sigma^{'}left( oldsymbol{h}^{(t + 1)} ight)} ight)oldsymbol{delta}^{(t)}left( oldsymbol{x}^{(t)} ight)^{T}}}}}$

     


     

    4.3 RNN发生梯度消失与梯度爆炸的原因分析

    上一节我们得到了从$oldsymbol{h}^{(t + 1)}$到$oldsymbol{h}^{(t)}$的递推公式:

    $frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} = diagleft( {sigma^{'}left( oldsymbol{h}^{({t + 1})} ight)} ight)oldsymbol{W}$

     

    在求$oldsymbol{h}^{(t)}$的时候,我们需要从$oldsymbol{h}^{(T)}$开始根据上面这个公式一步一步推到$oldsymbol{h}^{(t)}$,可以想象$oldsymbol{W}$在这期间会被疯狂地连乘。当我们要求某个时间步$t$下的$frac{partial L}{partialoldsymbol{W}^{(t)}}$时,这一堆连乘的$oldsymbol{W}$也会被带上。结果就是(粗略地分析),如果$oldsymbol{W}$里的值都比较大,就会发生梯度爆炸,反之则发生梯度消失。

    如果本文对您有所帮助的话,不妨点下“推荐”让它能帮到更多的人,谢谢。


     参考资料

    • 书籍:《Deep Learning》(深度学习)
  • 相关阅读:
    带你剖析WebGis的世界奥秘----点和线的世界
    XML解析
    Java-工厂设计模式
    分享:软件包和文档
    启航,新开始
    docker容器网络通信原理分析(转)
    【好书分享】容器网络到kubernetes网络
    go语言接受者的选取
    go语言的unsafe包(转)
    protocol buffers生成go代码原理
  • 原文地址:https://www.cnblogs.com/sumwailiu/p/13614859.html
Copyright © 2011-2022 走看看