zoukankan      html  css  js  c++  java
  • 随时间的反向传播算法 BPTT

    本文转自:https://www.cntofu.com/book/85/dl/rnn/bptt.md

    随时间反向传播(BPTT)算法


    先简单回顾一下RNN的基本公式:

     

    st=tanh(Uxt+Wst1)st=tanh⁡(Uxt+Wst−1)

     

    y^t=softmax(Vst)y^t=softmax(Vst)

    RNN的损失函数定义为交叉熵损失:

     

    Et(yt,y^t)=ytlogy^tEt(yt,y^t)=−ytlog⁡y^t

     

    E(y,y^)=tEt(yt,y^t)=tytlogy^tE(y,y^)=∑tEt(yt,y^t)=−∑tytlog⁡y^t

     

    ytyt

    是时刻t的样本实际值, 

    y^ty^t

    是预测值,我们通常把整个序列作为一个训练样本,所以总的误差就是每一步的误差的加和。我们的目标是计算损失函数的梯度,然后通过梯度下降方法学习出所有的参数U, V, W。比如:

    EW=tEtW∂E∂W=∑t∂Et∂W

    为了更好理解BPTT我们来推导一下公式:

    前向 前向传播1:

     

    a0=x0ua0=x0∗u

     

    b0=s1wb0=s−1∗w

     

    z0=a0+b0+kz0=a0+b0+k

     

    s0=func(z0)s0=func(z0)

     (

    funcfunc

     是 sig或者tanh)

    前向 前向传播2:

     

    a1=x1ua1=x1∗u

     

    b1=s0wb1=s0∗w

     

    z1=a1+b1+kz1=a1+b1+k

     

    s1=func(z1)s1=func(z1)

    (

    funcfunc

     是 sig 或者tanh)

     

    q=s1v1q=s1∗v1

    $$z_t = ux_t + ws_{t-1} + k$$

     

    st=func(zt)st=func(zt)

    输出 层:

     

    o=func(q)o=func(q)

    (

    funcfunc

     是 softmax)

     

    E=func(o)E=func(o)

    (

    funcfunc

     是 x-entropy)

    下面 是U的推导

     

    E/u=E/u1+E/u0∂E/∂u=∂E/∂u1+∂E/∂u0

     

    E/u1=E/oo/qq/s1s1/z1z1/a1a1/u1∂E/∂u1=∂E/∂o∗∂o/∂q∗∂q/∂s1∗∂s1/∂z1∗∂z1/∂a1∗∂a1/∂u1

     

    E/u0=E/oo/qq/s1s1/z1z1/b1b1/s0s0/dz0z0/a0a0/u0∂E/∂u0=∂E/∂o∗∂o/∂q∗∂q/∂s1∗∂s1/∂z1∗∂z1/∂b1∗∂b1/∂s0∗∂s0/dz0∗∂z0/∂a0∗∂a0/∂u0

     

    E/u=E/oo/qv1s1/z1((1x1)+(1w1s0/z01x0))∂E/∂u=∂E/∂o∗∂o/∂q∗v1∗∂s1/∂z1∗((1∗x1)+(1∗w1∗∂s0/∂z0∗1∗x0))

     

    E/u=E/oo/qv1s1/z1(x1+w1s0/z0x0)∂E/∂u=∂E/∂o∗∂o/∂q∗v1∗∂s1/∂z1∗(x1+w1∗∂s0/∂z0∗x0)

    W参数的推导如下

     

    E/w=E/oo/qv1s1/z1(s0+w1s0/z0s1)∂E/∂w=∂E/∂o∗∂o/∂q∗v1∗∂s1/∂z1∗(s0+w1∗∂s0/∂z0∗s−1)

    总结

     

    Lu=tLut=Loos1s1u1+Loos1s1s0s0u0∂L∂u=∑t∂L∂ut=∂L∂o∂o∂s1∂s1∂u1+∂L∂o∂o∂s1∂s1∂s0∂s0∂u0

     

    Lw=tLwt=Loos1s1w1+Loos1s1s0s0w0∂L∂w=∑t∂L∂wt=∂L∂o∂o∂s1∂s1∂w1+∂L∂o∂o∂s1∂s1∂s0∂s0∂w0

     

    xtxt

    是时间t的输入

  • 相关阅读:
    消息中间件——RabbitMQ(六)理解Exchange交换机核心概念!
    消息中间件——RabbitMQ(五)快速入门生产者与消费者,SpringBoot整合RabbitMQ!
    消息中间件——RabbitMQ(四)命令行与管控台的基本操作!
    消息中间件——RabbitMQ(三)理解RabbitMQ核心概念和AMQP协议!
    LayUI的基本使用
    Git报错:Your branch is up to date with 'origin/master'.
    Git报错:Please tell me who you are.
    Git报错:Permission denied (publickey)
    在 windows 上安装 git 2.22
    在 windows 上安装 git 2.15
  • 原文地址:https://www.cnblogs.com/carlber/p/11084932.html
Copyright © 2011-2022 走看看