zoukankan      html  css  js  c++  java
  • RNN,LSTM,GRU基本原理的个人理解

    记录一下对RNN,LSTM,GRU基本原理(正向过程以及简单的反向过程)的个人理解

    RNN

    Recurrent Neural Networks,循环神经网络
    (注意区别于recursive neural network,递归神经网络)

    为了解决DNN存在着无法对时间序列上的变化进行建模的问题(如自然语言处理、语音识别、手写体识别),出现的另一种神经网络结构——循环神经网络RNN。

    RNN结构

    第tt层神经元的输入,除了其自身的输入xtxt,还包括上一层神经元的隐含层输出st−1st−1
    每一层的参数U,W,V都是共享的


    每一层并不一定都得有输入和输出,如对句子进行情感分析是多到一,文本翻译多到多,图片描述一到多
    数学描述

    (以下开始符号统一)
    回忆一下单隐含层的前馈神经网络
    输入为X∈Rn×xX∈Rn×x(n个维度为x的向量)
    隐含层输出为
    H=ϕ(XWxh+bh)
    H=ϕ(XWxh+bh)

    输出层输入H∈Rn×hH∈Rn×h
    输出为
    Y^=softmax(HWhy+by)
    Y^=softmax(HWhy+by)

    现在对XX、HH、YY都加上时序下标
    同时引入一个新权重Whh∈Rh×hWhh∈Rh×h
    得到RNN表达式
    Ht=ϕ(XtWxh+Ht−1Whh+bh)
    Ht=ϕ(XtWxh+Ht−1Whh+bh)
    Y^t=softmax(HtWhy+by)
    Y^t=softmax(HtWhy+by)

    H0H0通常置零
    深层RNN和双向RNN


    通过时间反向传播和随之带来的问题

    输入为xt∈Rxxt∈Rx
    不考虑偏置
    隐含层变量为
    ht=ϕ(Whxxt+Whhht−1)
    ht=ϕ(Whxxt+Whhht−1)

    输出层变量为
    ot=Wyhht
    ot=Wyhht

    则损失函数为
    L=1T∑t=1Tℓ(ot,yt)
    L=1T∑t=1Tℓ(ot,yt)
    以一个三层为例

    三个参数更新公式为
    Whx=Whx−η∂L∂Whx
    Whx=Whx−η∂L∂Whx
    Whh=Whh−η∂L∂Whh
    Whh=Whh−η∂L∂Whh
    Wyh=Wyh−η∂L∂Wyh
    Wyh=Wyh−η∂L∂Wyh

    明显的
    ∂L∂ot=∂ℓ(ot,yt)T⋅∂ot
    ∂L∂ot=∂ℓ(ot,yt)T⋅∂ot

    根据链式法则
    ∂L∂Wyh=∑t=1Tprod(∂L∂ot,∂ot∂Wyh)=∑t=1T∂L∂oth⊤t
    ∂L∂Wyh=∑t=1Tprod(∂L∂ot,∂ot∂Wyh)=∑t=1T∂L∂otht⊤

    先计算目标函数有关最终时刻隐含层变量的梯度
    ∂L∂hT=prod(∂L∂oT,∂oT∂hT)=W⊤yh∂L∂oT
    ∂L∂hT=prod(∂L∂oT,∂oT∂hT)=Wyh⊤∂L∂oT

    假设ϕ(x)=xϕ(x)=x(RNN中用激活函数relu还是tanh众说纷纭,有点玄学)
    ∂L∂ht=prod(∂L∂ht+1,∂ht+1∂ht)+prod(∂L∂ot,∂ot∂ht)=W⊤hh∂L∂ht+1+W⊤yh∂L∂ot
    ∂L∂ht=prod(∂L∂ht+1,∂ht+1∂ht)+prod(∂L∂ot,∂ot∂ht)=Whh⊤∂L∂ht+1+Wyh⊤∂L∂ot

    通项为
    ∂L∂ht=∑i=tT(W⊤hh)T−iW⊤yh∂L∂oT+t−i
    ∂L∂ht=∑i=tT(Whh⊤)T−iWyh⊤∂L∂oT+t−i
    注意上式,当每个时序训练数据样本的时序长度T较大或者时刻t较小,目标函数有关隐含层变量梯度较容易出现衰减和爆炸

    ∂L∂Whx=∑t=1Tprod(∂L∂ht,∂ht∂Whx)=∑t=1T∂L∂htx⊤t
    ∂L∂Whx=∑t=1Tprod(∂L∂ht,∂ht∂Whx)=∑t=1T∂L∂htxt⊤
    ∂L∂Whh=∑t=1Tprod(∂L∂ht,∂ht∂Whh)=∑t=1T∂L∂hth⊤t−1
    ∂L∂Whh=∑t=1Tprod(∂L∂ht,∂ht∂Whh)=∑t=1T∂L∂htht−1⊤
    梯度裁剪

    为了应对梯度爆炸,一个常用的做法是如果梯度特别大,那么就投影到一个比较小的尺度上。θθ为设定的裁剪“阈值”,为标量,若梯度的范数大于此阈值,将梯度缩小,若梯度的范数小于此阈值,梯度不变
    g=min(θ∥g∥,1)g
    g=min(θ‖g‖,1)g
    LSTM

    RNN的隐含层变量梯度可能会出现衰减或爆炸。虽然梯度裁剪可以应对梯度爆炸,但无法解决梯度衰减。因此,给定一个时间序列,例如文本序列,循环神经网络在实际中其实较难捕捉两个时刻距离较大的文本元素(字或词)之间的依赖关系。
    LSTM(long short-term memory)由Hochreiter和Schmidhuber在1997年被提出。

    LSTM结构

    这里两张图先不用细看,先着重记住公式后再回来看


    数学描述

    (同上,符号统一)
    设隐含状态长度hh,tt时刻输入Xt∈Rn×xXt∈Rn×x(xx维)及t−1t−1时刻隐含状态Ht−1∈Rn×hHt−1∈Rn×h,
    输入门,遗忘门,输出门,候选细胞如下

    It=σ(XtWxi+Ht−1Whi+bi)
    It=σ(XtWxi+Ht−1Whi+bi)
    Ft=σ(XtWxf+Ht−1Whf+bf)
    Ft=σ(XtWxf+Ht−1Whf+bf)
    Ot=σ(XtWxo+Ht−1Who+bo)
    Ot=σ(XtWxo+Ht−1Who+bo)
    C~t=tanh(XtWxc+Ht−1Whc+bc)
    C~t=tanh(XtWxc+Ht−1Whc+bc)
    (思考侯选细胞激活函数的不同)
    记忆细胞
    Ct=Ft⊙Ct−1+It⊙C~t
    Ct=Ft⊙Ct−1+It⊙C~t

    想象,如果遗忘门一直近似1且输入门一直近似0,过去的细胞将一直通过时间保存并传递至当前时刻
    隐含状态
    Ht=Ot⊙tanh(Ct)
    Ht=Ot⊙tanh(Ct)

    输出同RNN
    Y^=softmax(HWhy+by)
    Y^=softmax(HWhy+by)
    GRU

    由Cho、van Merrienboer、 Bahdanau和Bengio在2014年提出,比LSTM少一个门控,实验结果却相当

    GRU结构

    数学描述

    设隐含状态长度hh,tt时刻输入Xt∈Rn×xXt∈Rn×x(xx维)及t−1t−1时刻隐含状态Ht−1∈Rn×hHt−1∈Rn×h,
    重置门,更新门如下
    Rt=σ(XtWxr+Ht−1Whr+br)
    Rt=σ(XtWxr+Ht−1Whr+br)
    Zt=σ(XtWxz+Ht−1Whz+bz)
    Zt=σ(XtWxz+Ht−1Whz+bz)

    候选隐含状态
    H~t=tanh(XtWxh+Rt⊙Ht−1Whh+bh)
    H~t=tanh(XtWxh+Rt⊙Ht−1Whh+bh)

    隐含状态
    Ht=Zt⊙Ht−1+(1−Zt)⊙H~t
    Ht=Zt⊙Ht−1+(1−Zt)⊙H~t

    输出
    Y^=softmax(HWhy+by)
    Y^=softmax(HWhy+by)
    (无力吐槽csdn了,预览和实际用的不一套渲染,公式丑死)
    ---------------------
    作者:lily_knight
    来源:CSDN
    原文:https://blog.csdn.net/qq_38210185/article/details/79376053
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    〖Linux〗Kubuntu设置打开应用时就只在打开时的工作区显示
    〖Linux〗Kubuntu, the application 'Google Chrome' has requested to open the wallet 'kdewallet'解决方法
    unity, dll is not allowed to be included or could not be found
    android check box 自定义图片
    unity, ios skin crash
    unity, Collider2D.bounds的一个坑
    unity, ContentSizeFitter立即生效
    类里的通用成员函数应声明为static
    unity, Gizmos.DrawMesh一个坑
    直线切割凹多边形
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11069130.html
Copyright © 2011-2022 走看看