前言
在本篇章,我们将专门针对LSTM这种网络结构进行前向传播介绍和反向梯度推导。
关于LSTM的梯度推导,这一块确实挺不好掌握,原因有:
- 一些经典的deep learning 教程,例如花书缺乏相关的内容
- 一些经典的论文不太好看懂,例如On the difficulty of training Recurrent Neural Networks上有LSTM的梯度推导但看得我还是一头雾水(可能是我能力有限。。)
- 网上关于LSTM的梯度推导虽多,但缺少保证其正确性的验证实验
考虑到上述问题,本篇章将以最低限度的知识依赖进行LSTM的反向梯度推导,所有推导基础均基于《神经网络的梯度推导与代码验证》之数学基础篇:矩阵微分与求导。为保证所得无误,后续将通过tensorflow的自动微分工具验证LSTM梯度推导结论的准确性。另外,为节约体能,推导过程相对没那么详细,因为这实际上只是反复应用矩阵微分与求导的原理而已,手把手教学级的内容请参考数学基础篇:矩阵微分与求导。
更多相关内容请见《神经网络的梯度推导与代码验证》系列介绍。
目录
- 5.1 LSTM的前向传播
- 5.1.2 LSTM的遗忘门
- 5.1.3 LSTM的输入门
- 5.1.4 LSTM的C(cell)状态更新
- 5.1.5 LSTM的输出门
- 5.1.6 LSTM的前向传播总结
- 5.2 LSTM的反向梯度推导
- 5.3 LSTM能改善梯度消失的原因
- 参考资料
提醒:
- 后续会反复出现$oldsymbol{delta}^{l}$这个(类)符号,它的定义为$oldsymbol{delta}^{l} = frac{partial l}{partialoldsymbol{z}^{oldsymbol{l}}}$,即loss $l$对$oldsymbol{z}^{oldsymbol{l}}$的导数
- 其中$oldsymbol{z}^{oldsymbol{l}}$表示第$l$层(DNN,CNN,RNN或其他例如max pooling层等)未经过激活函数的输出。
- $oldsymbol{a}^{oldsymbol{l}}$则表示$oldsymbol{z}^{oldsymbol{l}}$经过激活函数后的输出。
这些符号会贯穿整个系列,还请留意。
5.1 LSTM的前向传播
在RNN模型里,我们讲到了RNN具有如下的结构,每个序列索引位置$t$都有一个隐藏状态$oldsymbol{h}^{(t)}$。
如果我们只关注RNN的核心循环部分而不看$oldsymbol{o}^{(t)}$,$oldsymbol{L}^{(t)}$和$oldsymbol{y}^{(t)}$,则RNN的模型可以简化成如下图的形式:
图中可以很清晰看出在隐藏状态$oldsymbol{h}^{(t)}$由$oldsymbol{x}^{(t)}$和$oldsymbol{h}^{(t-1)}$共同得到。得到的$oldsymbol{h}^{(t)}$方面用于当前层的模型损失计算,另一方面用于计算下一层的$oldsymbol{h}^{(t+1)}$。
由于RNN梯度消失的问题,大牛们对于序列索引位置t的隐藏结构做了改进,可以说通过一些技巧让隐藏结构复杂了起来,来避免梯度消失的问题,这样的特殊RNN就是我们的LSTM。由于LSTM有很多的变种,这里我们以最常见的LSTM为例讲述。LSTM的结构如下图:
5.1.1 LSTM之细胞状态
上面我们给出了LSTM的模型结构,下面我们就一点点的剖析LSTM模型在每个序列索引位置$t$时刻的内部结构。
从上图中可以看出,在每个序列索引位置$t$时刻向前传播的除了和RNN一样的隐藏状态$oldsymbol{h}^{(t+1)}$,还多了另一个隐藏状态,如图中上面的长横线。这个隐藏状态我们一般称为细胞状态(Cell State),记为$oldsymbol{C}^{(t)}$。如下图所示:
我们可以看到从$oldsymbol{C}^{(t - 1)}$到$oldsymbol{C}^{(t)}$,似乎经过了若干乘法和加法操作。
除了细胞状态,LSTM图中还有了很多奇怪的结构,这些结构一般称之为门控结构(Gate)。LSTM在在每个序列索引位置t的门一般包括遗忘门,输入门和输出门三种。下面我们就来研究上图中LSTM的遗忘门,输入门和输出门以及细胞状态。
5.1.2 LSTM之遗忘门
遗忘门(forget gate)顾名思义,是控制是否遗忘的,在LSTM中即以一定的概率控制是否遗忘上一层的隐藏细胞状态。遗忘门子结构如下图所示:
图中输入的有上一序列的隐藏状态$oldsymbol{h}^{(t - 1)}$和$t$时刻的输入$oldsymbol{x}^{(t - 1)}$,通过一个激活函数(一般是sigmoid),得到遗忘门的输出$oldsymbol{f}^{(t)}$:
$oldsymbol{f}^{(t)} = sigmaleft( {oldsymbol{W}_{f}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{f}oldsymbol{x}^{(t - 1)} + oldsymbol{b}_{f}} ight)$
由于sigmoid的值域介于0~1之间,所以这里的$oldsymbol{f}^{(t)}$表示保留上一个时间步$oldsymbol{h}^{(t - 1)}$的多大的成分。虽然“保留”跟“遗忘门”这两个词是概念上相反的,但大家似乎已经习惯用遗忘门来称呼这个$oldsymbol{f}^{(t)}$了。
5.1.3 LSTM之输入门
输入门(input gate)负责管理当前序列位置的输入,它的子结构如下图:
输入门$oldsymbol{i}^{(t)}$的数学表达式为:
$oldsymbol{i}^{(t)} = sigmaleft( {oldsymbol{W}_{i}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{i}oldsymbol{x}^{(t - 1)} + oldsymbol{b}_{i}} ight)$
对比遗忘门的表达式,除了矩阵的下标发生了点改变以外,其他都一样。
而遗忘门的控制对象则是$oldsymbol{h}^{(t - 1)}$和$oldsymbol{x}^{(t - 1)}$组合的产物,它的表达式如下:
$oldsymbol{a}^{(t)} = sigmaleft( {oldsymbol{W}_{a}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{a}oldsymbol{x}^{(t - 1)} + oldsymbol{b}_{a}} ight)$
5.1.4 LSTM之细胞状态更新
在研究LSTM输出门之前,我们要先看看LSTM之细胞状态。前面的遗忘门和输入门的结果都会作用于细胞状态$oldsymbol{C}^{(t)}$。我们来看看$oldsymbol{C}^{(t - 1)}$是如何得到$oldsymbol{C}^{(t)}$的:
细胞状态$oldsymbol{C}^{(t)}$由两部分组成,第一部分是$oldsymbol{C}^{(t - 1)}$和遗忘门$oldsymbol{f}^{(t)}$的Hadamard积(逐元素相乘),第二部分是$oldsymbol{a}^{(t)}$和输入门$oldsymbol{i}^{(t)}$的Hadamard积:
$oldsymbol{C}^{(t)} = oldsymbol{C}^{(t)}igodotoldsymbol{f}^{(t)} + oldsymbol{a}^{(t)}igodotoldsymbol{i}^{(t)}$
5.1.5 LSTM之输出门
有了新的隐藏细胞状态$oldsymbol{C}^{(t)}$,现在来到输出门:
输出门$oldsymbol{o}^{(t)}$的数学表达式为:
$oldsymbol{o}^{(t)} = sigmaleft( {oldsymbol{W}_{o}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{o}oldsymbol{x}^{(t - 1)} + oldsymbol{b}_{o}} ight)$
而输出门所控制的对象,则是$tanhleft( oldsymbol{C}^{(t)} ight)$,两者共同形成$t$时间步下的隐藏状态$oldsymbol{h}^{(t)}$:
$oldsymbol{h}^{(t)} = oldsymbol{o}^{(t)}igodot tanhleft( oldsymbol{C}^{(t)} ight)$
5.1.6 LSTM前向传播总结
现在我们来总结下LSTM前向传播算法。LSTM模型有两个隐藏状态$oldsymbol{h}^{(t)}$,$oldsymbol{C}^{(t)}$,模型参数恰好是RNN的4倍整。
前向传播过程在每个时间步$t$上发生的顺序为:
1)更新遗忘门输出:
$oldsymbol{f}^{(t)} = sigmaleft( {oldsymbol{W}_{f}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{f}oldsymbol{x}^{(t)} + oldsymbol{b}_{f}} ight)$
2)更新输入门和其控制对象:
$oldsymbol{i}^{(t)} = sigmaleft( {oldsymbol{W}_{i}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{i}oldsymbol{x}^{(t)} + oldsymbol{b}_{i}} ight)$
$oldsymbol{a}^{(t)} = tanhleft( {oldsymbol{W}_{a}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{a}oldsymbol{x}^{(t)} + oldsymbol{b}_{a}} ight)$
3)更新细胞状态,从而$left. oldsymbol{C}^{(t - 1)}longrightarrowoldsymbol{C}^{(t)} ight.$:
$oldsymbol{C}^{(t)} = oldsymbol{C}^{(t - 1)}igodotoldsymbol{f}^{(t)} + oldsymbol{a}^{(t)}igodotoldsymbol{i}^{(t)}$
4)更新输出门和其控制对象,从而$left. oldsymbol{h}^{(t - 1)}longrightarrowoldsymbol{h}^{(t)} ight.$:
$oldsymbol{o}^{(t)} = sigmaleft( {oldsymbol{W}_{o}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{o}oldsymbol{x}^{(t - 1)} + oldsymbol{b}_{o}} ight)$
$oldsymbol{h}^{(t)} = oldsymbol{o}^{(t)}igodot tanhleft( oldsymbol{C}^{(t)} ight)$
5)得到当前时间步$t$的预测输出:
${hat{oldsymbol{y}}}^{(t)} = sigmaleft( {oldsymbol{V}oldsymbol{h}^{(t)} + oldsymbol{c}} ight)$
5.2 LSTM的反向梯度推导
在RNN中,为了计算反向传播误差,我们通过隐藏状态$oldsymbol{h}^{(t)}$的梯度$oldsymbol{delta}^{(t)}$一步一步向前传播。在LSTM中也类似,只不过我们这里由两种隐藏状态$oldsymbol{h}^{(t)}$和$oldsymbol{C}^{(t)}$,这里我们定义两种$oldsymbol{delta}$:
$oldsymbol{delta}_{h}^{(t)} = frac{partial L}{partialoldsymbol{h}^{(t)}}$
$oldsymbol{delta}_{C}^{(t)} = frac{partial L}{partialoldsymbol{C}^{(t)}}$
为了方便找到梯度的递推模式,下面是根据前向传播公式给出数据在LSTM中数据的前向流动示意图:
对于$t = T$,即时间序列截止的那个时间步,我们可以得到:
$oldsymbol{delta}_{h}^{(T)} = oldsymbol{V}^{T}left( {{hat{oldsymbol{y}}}^{(T)} - oldsymbol{y}^{(T)}} ight)$
$oldsymbol{delta}_{C}^{(T)} = left( frac{partialoldsymbol{h}^{(T)}}{partialoldsymbol{C}^{(T)}} ight)^{T}frac{partial L}{partialoldsymbol{h}^{(T)}} = oldsymbol{delta}_{h}^{(T)}igodotoldsymbol{o}^{(T)}igodot{tanh}^{'}left( oldsymbol{C}^{(T)} ight)$
第一个式子的证明见vanilla RNN的前向传播和反向梯度推导 的4.2节;第二个式子根据等式$oldsymbol{h}^{(t)} = oldsymbol{o}^{(t)}igodot tanhleft( oldsymbol{C}^{(t)} ight)$结合数学基础篇:矩阵微分与求导的理论即可秒证出来。
对于$t < T$时,我们要利用$oldsymbol{delta}_{h}^{(t + 1)}$和$oldsymbol{delta}_{C}^{(t + 1)}$递推得到$oldsymbol{delta}_{h}^{(t)}$和$oldsymbol{delta}_{C}^{(t)}$。
先来推导$oldsymbol{delta}_{h}^{(t)}$的递推公式:
根据上图我们知道,$oldsymbol{delta}_{h}^{(t)}$的误差来源如下:
1)$left. lleft( t ight)longrightarrowoldsymbol{h}^{(t)} ight.$
2)$left. oldsymbol{h}^{(t + 1)}longrightarrowoldsymbol{o}^{(t + 1)}longrightarrowoldsymbol{h}^{(t)} ight.$
3)$left. oldsymbol{C}^{(t + 1)}longrightarrowoldsymbol{i}^{(t + 1)}longrightarrowoldsymbol{h}^{(t)} ight.$
4)$left. oldsymbol{C}^{(t + 1)}longrightarrowoldsymbol{a}^{(t + 1)}longrightarrowoldsymbol{h}^{(t)} ight.$
5)$left. oldsymbol{C}^{(t + 1)}longrightarrowoldsymbol{f}^{(t + 1)}longrightarrowoldsymbol{h}^{(t)} ight.$
根据链式法则和全微分方程,有:
$oldsymbol{delta}_{h}^{(t)} = frac{partial Lleft( t ight)}{partialoldsymbol{h}^{(t)}} = frac{partial lleft( t ight)}{partialoldsymbol{h}^{(t)}} + left( frac{partialoldsymbol{C}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} ight)^{T}oldsymbol{delta}_{C}^{(t + 1)} + left( {frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{o}^{(t)}}frac{partialoldsymbol{o}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}} ight)^{T}oldsymbol{delta}_{h}^{(t + 1)}$
注意:上式中特地用了$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{o}^{(t)}}frac{partialoldsymbol{o}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$而不是$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$。因为在$oldsymbol{h}^{(t + 1)}$与$oldsymbol{h}^{(t)}$之间存在多条传播路径的情况下,$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{o}^{(t)}}frac{partialoldsymbol{o}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} eq frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$。我们用$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{o}^{(t)}}frac{partialoldsymbol{o}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$规定了从$oldsymbol{h}^{(t + 1)}$到$oldsymbol{h}^{(t)}$的误差传播路径必须是$left. oldsymbol{h}^{(t + 1)}longrightarrowoldsymbol{o}^{(t + 1)}longrightarrowoldsymbol{h}^{(t)} ight.$而不是其他的路径。如果是用$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$这个符号,则是默认要考虑所有从$oldsymbol{h}^{(t + 1)}$到$oldsymbol{h}^{(t)}$的误差传播路径。
上面这个递推公式需要解决三个问题,$frac{partial lleft( t ight)}{partialoldsymbol{h}^{(t)}}$,$left( frac{partialoldsymbol{C}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} ight)^{T}$和$left( {frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{o}^{(t)}}frac{partialoldsymbol{o}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}} ight)^{T}$的求解。
对于$frac{partial lleft( t ight)}{partialoldsymbol{h}^{(t)}}$,根据vanilla RNN的前向传播和反向梯度推导 的4.2节,它满足:
$frac{partial lleft( t ight)}{partialoldsymbol{h}^{(t)}} = oldsymbol{V}^{T}left( {{hat{oldsymbol{y}}}^{(t)} - oldsymbol{y}^{(t)}} ight)$
我们接下来求$frac{partialoldsymbol{C}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$:
注意:因为下面的公式实在太长了,所以为节省空间,我们用“~”表示这个位置原本的数学表达式与上一行相同位置的数学表达式一样。
基于$oldsymbol{C}^{(t)} = oldsymbol{C}^{(t - 1)}igodotoldsymbol{f}^{(t)} + oldsymbol{a}^{(t)}igodotoldsymbol{i}^{(t)}$逐层展开,我们得到:
$doldsymbol{C}^{(t + 1)} = oldsymbol{C}^{(t)}igodot doldsymbol{f}^{({t + 1})} + oldsymbol{i}^{({t + 1})}igodot doldsymbol{a}^{({t + 1})} + oldsymbol{a}^{({t + 1})}igodot doldsymbol{i}^{({t + 1})}$
$= diagleft( oldsymbol{C}^{(t)} ight)doldsymbol{f}^{({t + 1})} + diagleft( oldsymbol{i}^{({t + 1})} ight)doldsymbol{a}^{({t + 1})} + diagleft( oldsymbol{a}^{({t + 1})} ight)doldsymbol{i}^{({t + 1})}$
$= diagleft( oldsymbol{C}^{(t)} ight)doldsymbol{f}^{({t + 1})} + diagleft( oldsymbol{a}^{({t + 1})} ight)doldsymbol{i}^{({t + 1})} + diagleft( oldsymbol{i}^{({t + 1})} ight)doldsymbol{a}^{({t + 1})}$
$left. = diagleft( {oldsymbol{C}^{(t)}igodotoldsymbol{f}^{({t + 1})}igodotleft( {1 - oldsymbol{f}^{({t + 1})}} ight)} ight)oldsymbol{W}_{f}doldsymbol{h}^{(t)} + ight.simleft. + ight.sim$
$left. = ight.simleft. + diagleft( {oldsymbol{a}^{({t + 1})}igodotoldsymbol{i}^{({t + 1})}igodotleft( {1 - oldsymbol{i}^{({t + 1})}} ight)} ight)oldsymbol{W}_{i}doldsymbol{h}^{(t)} + ight.sim$
因为${tanh}^{'}left( x ight) = left( {1 - {tanhleft( x ight)}^{2}} ight)$,所以:
$left. doldsymbol{C}^{(t + 1)} = ight.simleft. + ight.sim + diagleft( {oldsymbol{i}^{({t + 1})}igodotleft( {1 - {oldsymbol{a}^{({t + 1})}}^{2}} ight)} ight)oldsymbol{W}_{a}doldsymbol{h}^{(t)}$
整理上式我们得到:
$frac{partialoldsymbol{C}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} = diagleft( {oldsymbol{C}^{(t)}igodotoldsymbol{f}^{({t + 1})}igodotleft( {1 - oldsymbol{f}^{({t + 1})}} ight)} ight)oldsymbol{W}_{f} + diagleft( {oldsymbol{a}^{({t + 1})}igodotoldsymbol{i}^{({t + 1})}igodotleft( {1 - oldsymbol{i}^{({t + 1})}} ight)} ight)oldsymbol{W}_{i} + diagleft( {oldsymbol{i}^{({t + 1})}igodotleft( {1 - {oldsymbol{a}^{({t + 1})}}^{2}} ight)} ight)oldsymbol{W}_{a}$
接下来是$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{o}^{(t)}}frac{partialoldsymbol{o}^{(t + 1)}}{partialoldsymbol{h}^{(t)}}$的推导过程:
$doldsymbol{h}^{({t + 1})} = tanhleft( oldsymbol{C}^{({t + 1})} ight)igodot doldsymbol{o}^{({t + 1})} = diagleft( {tanhleft( oldsymbol{C}^{({t + 1})} ight)} ight)diagleft( {oldsymbol{o}^{({t + 1})}igodotleft( {1 - oldsymbol{o}^{({t + 1})}} ight)} ight)dleft( {oldsymbol{W}_{o}oldsymbol{h}^{(t)}} ight) = diagleft( {tanhleft( oldsymbol{C}^{({t + 1})} ight)igodotoldsymbol{o}^{({t + 1})}igodotleft( {1 - oldsymbol{o}^{({t + 1})}} ight)} ight)oldsymbol{W}_{o}doldsymbol{h}^{(t)}$
所以$frac{partialoldsymbol{h}^{(t + 1)}}{partialoldsymbol{o}^{(t)}}frac{partialoldsymbol{o}^{(t + 1)}}{partialoldsymbol{h}^{(t)}} = diagleft( {tanhleft( oldsymbol{C}^{({t + 1})} ight)igodotoldsymbol{o}^{({t + 1})}igodotleft( {1 - oldsymbol{o}^{({t + 1})}} ight)} ight)$
于是我们现在得到了从$oldsymbol{delta}_{C}^{(t + 1)}$和$oldsymbol{delta}_{h}^{(t + 1)}$推得$oldsymbol{delta}_{h}^{(t)}$的递推公式。
接下来我们利用$oldsymbol{delta}_{h}^{(t)}$和$oldsymbol{delta}_{C}^{(t + 1)}$来推得$oldsymbol{delta}_{C}^{(t)}$:
根据LSTM的前向示意图,我们有:
$oldsymbol{delta}_{C}^{(t)} = left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{c}^{(t)}} ight)^{T}oldsymbol{delta}_{h}^{(t)} + {left( frac{partialoldsymbol{c}^{(t + 1)}}{partialoldsymbol{c}^{(t)}} ight)^{T}oldsymbol{delta}}_{C}^{(t + 1)}$
容易求得$frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{c}^{(t)}} = left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{C}^{(t)}} ight)^{T}frac{partial Lleft( t ight)}{partialoldsymbol{h}^{(t)}} = oldsymbol{o}^{(t)} odot left( {1 - {tanh}^{2}left( oldsymbol{C}^{(t)} ight)} ight)^{2}$
同样也容易求得$frac{partialoldsymbol{c}^{(t + 1)}}{partialoldsymbol{c}^{(t)}} = diagleft( oldsymbol{f}^{({t + 1})} ight)$
所以得到:
$oldsymbol{delta}_{C}^{(t)} = left( frac{partialoldsymbol{h}^{(t)}}{partialoldsymbol{c}^{(t)}} ight)^{T}oldsymbol{delta}_{h}^{(t)} + {left( frac{partialoldsymbol{c}^{(t + 1)}}{partialoldsymbol{c}^{(t)}} ight)^{T}oldsymbol{delta}}_{C}^{(t + 1)} = oldsymbol{o}^{(t)} odot left( {1 - {tanh}^{2}left( oldsymbol{C}^{(t)} ight)} ight)^{2}{odot oldsymbol{delta}}_{h}^{(t)} + oldsymbol{f}^{({t + 1})} odot oldsymbol{delta}_{C}^{(t + 1)}$
现在,我们能计算$oldsymbol{delta}_{h}^{(t)}$和$oldsymbol{delta}_{C}^{(t)}$了,有了它们,计算变量的梯度就比较容易了,这里只以计算$oldsymbol{W}_{f}$的梯度计算为例:
我们令${oldsymbol{z}^{(t)} = oldsymbol{W}}_{f}oldsymbol{h}^{(t - 1)} + oldsymbol{U}_{f}oldsymbol{x}^{(t)} + oldsymbol{b}_{f}$,则:
$frac{partial L}{partialoldsymbol{W}_{f}} = {sumlimits_{t = 1}^{T}left( frac{partialoldsymbol{C}_{t}}{partialoldsymbol{z}^{(t)}} ight)^{T}}frac{partial L}{partialoldsymbol{C}_{t}}left( oldsymbol{h}^{(t - 1)} ight)^{T}$
$doldsymbol{C}^{(t)} = oldsymbol{C}^{(t - 1)} odot doldsymbol{f}^{(t)} = diagleft( oldsymbol{C}^{({t - 1})} ight)left( {left( {oldsymbol{f}^{(t)} odot left( {1 - oldsymbol{f}^{(t)}} ight)} ight) odot doldsymbol{z}^{(t)}} ight) = diagleft( oldsymbol{C}^{({t - 1})} ight)left( {diagleft( {oldsymbol{f}^{(t)} odot left( {1 - oldsymbol{f}^{(t)}} ight)} ight)doldsymbol{z}^{(t)}} ight) = diagleft( {oldsymbol{f}^{(t)} odot left( {1 - oldsymbol{f}^{(t)}} ight) odot oldsymbol{C}^{({t - 1})}} ight)doldsymbol{z}^{(t)}$
所以$frac{partialoldsymbol{C}_{t}}{partialoldsymbol{z}^{(t)}} = diagleft( {oldsymbol{f}^{(t)} odot left( {1 - oldsymbol{f}^{(t)}} ight) odot oldsymbol{C}^{({t - 1})}} ight)$
所以得到:
$frac{partial L}{partialoldsymbol{W}_{f}} = {sumlimits_{t = 1}^{T}leftlbrack {oldsymbol{delta}_{C}^{(t)} odot oldsymbol{C}^{(t - 1)} odot oldsymbol{f}^{(t)} odot leftlbrack {1 - oldsymbol{f}^{(t)}} ight brack} ight brack}left( oldsymbol{h}^{(t - 1)} ight)^{T}$
其他变量的梯度按照上述类似的方式可依次求得,在这里不做过多叙述。
5.3 LSTM 能改善梯度消失的原因
首先需要明确的是,RNN 中的梯度消失/梯度爆炸和普通的 MLP 或者深层 CNN 中梯度消失/梯度爆炸的含义不一样。MLP/CNN 中不同的层有不同的参数,各是各的梯度;而 RNN 中同样的权重在各个时间步共享,最终的梯度$~g$= 各个时间步的梯度$g^{(t)}$之和。
因此,RNN 中总的梯度是不会消失的。即便梯度越传越弱,那也只是远距离的梯度消失,由于近距离的梯度不会消失,所有梯度之和便不会消失。RNN 所谓梯度消失的真正含义是,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。
LSTM 中梯度的传播有很多条路径,但$oldsymbol{C}^{(t)} = oldsymbol{C}^{(t - 1)}igodotoldsymbol{f}^{(t)} + oldsymbol{a}^{(t)}igodotoldsymbol{i}^{(t)}$这条路径上只有逐元素相乘和相加的操作,梯度流最稳定;但是其他路径上梯度流与普通 RNN 类似,照样会发生相同的权重矩阵反复连乘。
由于总的远距离梯度 = 各条路径的远距离梯度之和,即便其他远距离路径梯度消失了,只要保证有一条远距离路径(就是上面说的那条高速公路)梯度不消失,总的远距离梯度就不会消失(正常梯度 + 消失梯度 = 正常梯度)。因此 LSTM 通过改善一条路径上的梯度问题拯救了总体的远距离梯度。
如果本文对您有所帮助的话,不妨点下“推荐”让它能帮到更多的人,谢谢。
参考资料
- https://www.zhihu.com/question/34878706/answer/665429718
- https://weberna.github.io/blog/2017/11/15/LSTM-Vanishing-Gradients.html
- https://www.cnblogs.com/pinard/p/6519110.html
(欢迎转载,转载请注明出处。欢迎留言或沟通交流: lxwalyw@gmail.com)