花书第10.2.2节的计算循环神经网络的梯度看了好久,总算是把公式的推导给看懂了,记录一下过程。
首先,对于一个普通的RNN来说,其前向传播过程为:
$$ extbf{a}^{(t)}= extbf{b}+ extbf{Wh}^{t-1}+ extbf{Ux}^{(t)}$$
$$ extbf{h}^t=tanh( extbf{a}^{(t)})$$
$$ extbf{o}^{(t)} = extbf{c} + extbf{V} extbf{h}^{(t)}$$
$$hat{ extbf{y}}^{(t)} = softmax( extbf{o}^{(t)})$$
先介绍一下等下计算过程中会用到的偏导数:
$$h = tanh(a) = frac{e^a-e^{-a}}{e^a+e^{-a}}$$
$$frac{partial extbf{h}}{partial extbf{a}} = diag(1- extbf{h}^2)$$
另一个,当$ extbf{y}$采用one-hot并且损失函数$L$为交叉熵时:
$$frac{partial L}{partial extbf{o}^{(t)}} = frac{partial L}{partial L^{(t)}}frac{partial L^{(t)}}{partial extbf{o}^{t}} = hat{ extbf{y}}^{(t)}- extbf{y}^{(t)}$$
【注】这里涉及到softmax求导的规律,如果不懂的话可以看看:传送门
接下来从RNN的尾部开始,逐步计算隐藏状态$ extbf{h}^t$的梯度。如果$ au$是最后的时间步,$ extbf{h}^{( au)}$就是最后的隐藏输出。
$$frac{partial L}{partial extbf{h}^{( au)}} = frac{partial L}{partial extbf{o}^{( au)}}frac{partial extbf{o}^{( au)}}{partial extbf{h}^{( au)}}= extbf{V}^T(hat{ extbf{y}}^{( au)}- extbf{y}^{( au)})$$
然后一步步往前计算$ extbf{h}^t$的梯度,注意$ extbf{h}^{(t)}(t< au)$同时有$ extbf{o}^{(t)}$和$ extbf{h}^{(t+1)}$两个后续节点,所以:
$$frac{partial L}{partial extbf{h}^{(t)}}=(frac{partial extbf{h}^{(t+1)}}{partial extbf{h}^{(t)}})^Tfrac{partial L}{partial extbf{h}^{(t+1)}}+(frac{partial extbf{o}^{(t)}}{partial extbf{h}^{(t)}})^Tfrac{partial L}{partial extbf{o}^{(t)}}=(frac{partial extbf{h}^{(t+1)}}{partial extbf{a}^{(t+1)}} frac{partial extbf{a}^{(t+1)}}{partial extbf{h}^{(t)}})^T frac{partial L}{partial extbf{h}^{(t+1)}}+ extbf{V}^T(hat{ extbf{y}}^{(t)}- extbf{y}^{(t)})= extbf{W}^T(diag(1-( extbf{h}^{(t+1)})^2))frac{partial L}{partial extbf{h}^{(t+1)}}+ extbf{V}^T(hat{ extbf{y}}^{(t)}- extbf{y}^{(t)})$$
【注】这里的结果和花书有点不一样,不知道是花书有错误还是我这里错了?
剩下的参数计算起来就简单多了:
$$frac{partial L}{partial extbf{W}} = sum_{t=1}^{ au}frac{partial L}{partial extbf{h}^{(t)}}frac{partial extbf{h}^{(t)}}{partial extbf{W}} = sum_{t=1}^{ au}frac{partial L}{partial extbf{h}^{(t)}}frac{partial extbf{h}^{(t)}}{partial extbf{a}^{(t)}}frac{partial extbf{a}^{(t)}}{partial extbf{W}} = sum_{t=1}^{ au}diag(1-( extbf{h}^{(t)})^2)frac{partial L}{partial extbf{h}^{(t)}}( extbf{h}^{(t-1)})^T$$
$$frac{partial L}{partial extbf{b}}= sumlimits_{t=1}^{ au}diag(1-( extbf{h}^{(t)})^2)frac{partial L}{partial extbf{h}^{(t)}}$$
$$frac{partial L}{partial extbf{U}} =sumlimits_{t=1}^{ au}diag(1-( extbf{h}^{(t)})^2)frac{partial L}{partial extbf{h}^{(t)}}( extbf{x}^{(t)})^T$$
$$frac{partial L}{partial extbf{c}} = sumlimits_{t=1}^{ au}frac{partial L^{(t)}}{partial extbf{c}} = sumlimits_{t=1}^{ au}hat{ extbf{y}}^{(t)} - extbf{y}^{(t)}$$
$$frac{partial L}{partial extbf{V}} =sumlimits_{t=1}^{ au}frac{partial L^{(t)}}{partial extbf{V}} = sumlimits_{t=1}^{ au}(hat{ extbf{y}}^{(t)} - extbf{y}^{(t)}) ( extbf{h}^{(t)})^T$$
参考