zoukankan      html  css  js  c++  java
  • 花书BPTT公式推导

    花书第10.2.2节的计算循环神经网络的梯度看了好久,总算是把公式的推导给看懂了,记录一下过程。

    首先,对于一个普通的RNN来说,其前向传播过程为:

    $$ extbf{a}^{(t)}= extbf{b}+ extbf{Wh}^{t-1}+ extbf{Ux}^{(t)}$$

    $$ extbf{h}^t=tanh( extbf{a}^{(t)})$$

    $$ extbf{o}^{(t)} = extbf{c} + extbf{V} extbf{h}^{(t)}$$

    $$hat{ extbf{y}}^{(t)} = softmax( extbf{o}^{(t)})$$

    先介绍一下等下计算过程中会用到的偏导数:

    $$h = tanh(a) = frac{e^a-e^{-a}}{e^a+e^{-a}}$$

    $$frac{partial extbf{h}}{partial extbf{a}} = diag(1- extbf{h}^2)$$

    另一个,当$ extbf{y}$采用one-hot并且损失函数$L$为交叉熵时:

    $$frac{partial L}{partial extbf{o}^{(t)}} = frac{partial L}{partial L^{(t)}}frac{partial L^{(t)}}{partial extbf{o}^{t}} = hat{ extbf{y}}^{(t)}- extbf{y}^{(t)}$$

    【注】这里涉及到softmax求导的规律,如果不懂的话可以看看:传送门

    接下来从RNN的尾部开始,逐步计算隐藏状态$ extbf{h}^t$的梯度。如果$ au$是最后的时间步,$ extbf{h}^{( au)}$就是最后的隐藏输出。

    $$frac{partial L}{partial extbf{h}^{( au)}} = frac{partial L}{partial extbf{o}^{( au)}}frac{partial extbf{o}^{( au)}}{partial extbf{h}^{( au)}}= extbf{V}^T(hat{ extbf{y}}^{( au)}- extbf{y}^{( au)})$$

    然后一步步往前计算$ extbf{h}^t$的梯度,注意$ extbf{h}^{(t)}(t< au)$同时有$ extbf{o}^{(t)}$和$ extbf{h}^{(t+1)}$两个后续节点,所以:

    $$frac{partial L}{partial extbf{h}^{(t)}}=(frac{partial extbf{h}^{(t+1)}}{partial extbf{h}^{(t)}})^Tfrac{partial L}{partial extbf{h}^{(t+1)}}+(frac{partial extbf{o}^{(t)}}{partial extbf{h}^{(t)}})^Tfrac{partial L}{partial extbf{o}^{(t)}}=(frac{partial extbf{h}^{(t+1)}}{partial extbf{a}^{(t+1)}} frac{partial extbf{a}^{(t+1)}}{partial extbf{h}^{(t)}})^T frac{partial L}{partial extbf{h}^{(t+1)}}+ extbf{V}^T(hat{ extbf{y}}^{(t)}- extbf{y}^{(t)})= extbf{W}^T(diag(1-( extbf{h}^{(t+1)})^2))frac{partial L}{partial extbf{h}^{(t+1)}}+ extbf{V}^T(hat{ extbf{y}}^{(t)}- extbf{y}^{(t)})$$

    【注】这里的结果和花书有点不一样,不知道是花书有错误还是我这里错了?

    剩下的参数计算起来就简单多了:

    $$frac{partial L}{partial extbf{W}} = sum_{t=1}^{ au}frac{partial L}{partial extbf{h}^{(t)}}frac{partial extbf{h}^{(t)}}{partial extbf{W}} = sum_{t=1}^{ au}frac{partial L}{partial extbf{h}^{(t)}}frac{partial extbf{h}^{(t)}}{partial extbf{a}^{(t)}}frac{partial extbf{a}^{(t)}}{partial extbf{W}} = sum_{t=1}^{ au}diag(1-( extbf{h}^{(t)})^2)frac{partial L}{partial extbf{h}^{(t)}}( extbf{h}^{(t-1)})^T$$

    $$frac{partial L}{partial extbf{b}}= sumlimits_{t=1}^{ au}diag(1-( extbf{h}^{(t)})^2)frac{partial L}{partial extbf{h}^{(t)}}$$

    $$frac{partial L}{partial extbf{U}} =sumlimits_{t=1}^{ au}diag(1-( extbf{h}^{(t)})^2)frac{partial L}{partial extbf{h}^{(t)}}( extbf{x}^{(t)})^T$$

    $$frac{partial L}{partial extbf{c}} = sumlimits_{t=1}^{ au}frac{partial L^{(t)}}{partial extbf{c}}  = sumlimits_{t=1}^{ au}hat{ extbf{y}}^{(t)} - extbf{y}^{(t)}$$

    $$frac{partial L}{partial extbf{V}} =sumlimits_{t=1}^{ au}frac{partial L^{(t)}}{partial extbf{V}}  = sumlimits_{t=1}^{ au}(hat{ extbf{y}}^{(t)} - extbf{y}^{(t)}) ( extbf{h}^{(t)})^T$$

    参考

    【1】RNN前向方向传播(花书《深度学习》10.2循环神经网络)

    【2】循环神经网络(RNN)模型与前向反向传播算法

  • 相关阅读:
    下载windows原装镜像的官方网站
    Typora快捷键
    UOS使用ZSH终端教程
    UOS每日折腾、调教、美化
    AMD64和X86_64
    CPU架构
    23种设计模式---单例设计模式(精华)
    java学习day32-Servlet上下文--ServletContext
    java学习day32-Servlet过滤器-Filter
    java学习day32-JSP标签技术-JSTL标签库
  • 原文地址:https://www.cnblogs.com/zyb993963526/p/13797144.html
Copyright © 2011-2022 走看看