zoukankan      html  css  js  c++  java
  • RNN中的梯度消失爆炸原因

    RNN中的梯度消失/爆炸原因

    梯度消失/梯度爆炸是深度学习中老生常谈的话题,这篇博客主要是对RNN中的梯度消失/梯度爆炸原因进行公式层面上的直观理解。

    首先,上图是RNN的网络结构图,((x_1, x_2, x_3, …, ))是输入的序列,(X_t)表示时间步为(t)时的输入向量。假设我们总共有(k)个时间步,用第(k)个时间步的输出(H_k)作为输出(实际上每个时间步都有输出,这里仅考虑(H_k)),用(E_k)表示损失。

    其中,(C_{t}= anh left(W_{c} C_{t-1}+W_{x} X_{t} ight))

    从上式可以看出 (W_x)(W_c)其实是差不多的,记(W=[W_c, W_x]),那么求偏导可以得到:

    (egin{aligned} frac{partial E_{k}}{partial W}=& frac{partial E_{k}}{partial H_{k}} frac{partial H_{k}}{partial C_{k}} frac{partial C_{k}}{partial C_{k-1}} ldots frac{partial C_{2}}{partial C_{1}} frac{partial C_{1}}{partial W}=\ & frac{partial E_{k}}{partial H_{k}} frac{partial H_{k}}{partial C_{k}}left(prod_{t=2}^{k} frac{partial C_{t}}{partial C_{t-1}} ight) frac{partial C_{1}}{partial W} end{aligned})

    其中的累乘部分为:

    (egin{aligned} frac{partial C_{t}}{partial c_{t-1}}=& anh ^{prime}left(W_{c} C_{t-1}+W_{x} X_{t} ight) cdot frac{d}{d C_{t-1}}left[W_{c} C_{t-1}+W_{x} X_{t} ight]=\ & anh ^{prime}left(W_{c} C_{t-1}+W_{x} X_{t} ight) cdot W_{c} end{aligned})

    将该式代入上式有:

    (frac{partial E_{k}}{partial W}=frac{partial E_{k}}{partial H_{k}} frac{partial H_{k}}{partial C_{k}}left(prod_{t=2}^{k} anh ^{prime}left(W_{c} C_{t-1}+W_{x} X_{t} ight) cdot W_{c} ight) frac{partial c_{1}}{partial W})

    观察这个式子,和上篇文章中一样,因为链式法则,出现了累乘项,因为tanh的导数 <= 1,所以,当k很大的时候,上式的值是趋向于0的。(<1的数多次相乘),也就是:

    (Pi_{t=2}^{k} anh ^{prime}left(W_{c} C_{t-1}+w_{x} X_{t} ight) cdot W_{c} ightarrow 0,) so (frac{partial E_{k}}{partial W} ightarrow 0)

    此时,权重更新公式:

    (W leftarrow W-alpha frac{partial E_{k}}{partial W} approx W)

    也就是说,RNN很容易出现梯度消失现象,使得参数更新缓慢,甚至是停止更新。

  • 相关阅读:
    vue基础知识
    git的创建使用
    使用express搭建服务器框架
    日常训练
    今日收获
    今日收获
    今日收获
    今日收获
    今日收获
    今日收获
  • 原文地址:https://www.cnblogs.com/Elaine-DWL/p/11240209.html
Copyright © 2011-2022 走看看