zoukankan      html  css  js  c++  java
  • RNN、lstm和GRU推导

    RNN:(Recurrent Neural Networks)循环神经网络

    • t层神经元的输入,除了其自身的输入xt,还包括上一层神经元的隐含层输出st1
    • 每一层的参数U,W,V都是共享的

    lstm:长短时记忆网络,是一种改进后的循环神经网络,可以解决RNN无法处理的长距离依赖问题。

     

     原始 RNN 的隐藏层只有一个状态,即h,它对于短期的输入非常敏感。再增加一个状态,即c,让它来保存长期的状态,称为单元状态(cell state)。

     

    按照时间维度展开如下所示:

     

     在t时刻,lstm的输入有三个:当前时刻的网络的输入值、上时刻lstm的输出值、以及上一时刻的单元状态;lstm的输出有两个:当前时刻lstm的输出值、和当前时刻的单元状态。使用三个控制开关控制长期状态c:

     

    在算法中利用门实现三个状态的功能:

    门就是一个全连接层,输入的是一个向量,输出是一个0到1之间的实数向量。

     

     

     门控制的原理:用门的输出向量按照元素乘以我们需要控制的那个向量,门的输出不是0就是1,0乘以任何向量都是0代表不通过,1乘以任何向量不会发生改变。

    遗忘门的计算方式:

     

     

     遗忘门:决定了上一时刻的单元状态c_t-1有多少保留到了c_t当前状态,Wf 是遗忘门的权重矩阵,[ht-1,xt]表示将两个变量拼接起来,bf是遗忘门的偏置项,是sigmoid函数。

     

    输入门的计算:

     

     输入门:决定了当前时刻网络的输入x_t有多少保存到单元状态c_t.

    根据上一次的输出和本次输入计算当前输入的单元状态:

     

     

    当前时刻的单元状态c_t的计算由上一次的单元状态c_t-1乘以按元素乘以遗忘门ft,在用当前输入的单元状态c_t乘以输入门i_t,将两个积加和,可以将长期记忆和当前记忆结合起来形成新的单元状态。由于遗忘门的控制可以保存很久很久的信息。由于输入门的控制可以避免无关紧要的内容进入记忆。

    目标是要学习8组参数:

     

    权重矩阵是由两个矩阵拼接而成的。误差项是沿时间的反向传播,定义t时刻的误差项:

     

     

     权重矩阵计算公式如下:

     

    总体流程总结:

    原始输入循环体的是当前输入x_t和上前一步的输出h_{t-1},以及上一步的状态C_{t-1},

    x_th_{t-1}先遇到遗忘门(forget gate)

    f_{t}=sigmoid(W_f[h_{t-1},x_t]+b_f)

    经过遗忘门的函数之后产生一个0到1之间的输出f_t,代表遗忘多少之前的状态C_{t-1},当f_t为0时代表全部遗忘,1代表完全保持。

    另外一条路线上,x_th_{t-1}又会遇见输入门(input gate),输入门会决定记忆哪些值:

    i_t=sigmoid(W_i[h_{t-1},x_t]+b+i)

    另外同时经过tanh函数会产生一个新的状态C'_t

    C'_t=tanh(W_C[h_{t-1},x_t]+b_C)

    这个时候,由C_{t-1},f_t,C'_t,i_t就可以决定循环体的当前状态C_t了:

    C_t=f_t*C_{t-1}+i_t*C'_t

    有了当前的状态,自然就可以去输出门(output gate)了:

    o_t=sigmoid(W_o[h_{t-1},x_t]+b_o)

    h_t=o_t*tanh(C_t)

    从上面的公式,我们容易发现,每个门的形态是一样的,都是通过sigmoid函数作用于当前的输入x_t和前一时刻的输出h_{t-1}产生一个0到1的数值,以此来决定通过多少信息。

    GRU:gate recurrent unit ,门控循环单元(GRU)。GRU 旨在解决标准 RNN 中出现的梯度消失问题。GRU 也可以被视为 LSTM 的变体。

    GRU 背后的原理与 LSTM 非常相似,即用门控机制控制输入、记忆等信息而在当前时间步做出预测,表达式由以下给出:

    GRU 有两个门,即一个重置门(reset gate)和一个更新门(update gate)。从直观上来说,重置门决定了如何将新的输入信息与前面的记忆相结合,更新门定义了前面记忆保存到当前时间步的量。如果我们将重置门设置为 1,更新门设置为 0,那么我们将再次获得标准 RNN 模型。使用门控机制学习长期依赖关系的基本思想和 LSTM 一致,但还是有一些关键区别:

    • GRU 有两个门(重置门与更新门),而 LSTM 有三个门(输入门、遗忘门和输出门)。
    • GRU 并不会控制并保留内部记忆(c_t),且没有 LSTM 中的输出门。
    • LSTM 中的输入与遗忘门对应于 GRU 的更新门,重置门直接作用于前面的隐藏状态。
    • 在计算输出时并不应用二阶非线性

     

    1.更新门

    在时间步 t,我们首先需要使用以下公式计算更新门 z_t:

    其中 x_t 为第 t 个时间步的输入向量,即输入序列 X 的第 t 个分量,它会经过一个线性变换(与权重矩阵 W(z) 相乘)。h_(t-1) 保存的是前一个时间步 t-1 的信息,它同样也会经过一个线性变换。更新门将这两部分信息相加并投入到 Sigmoid 激活函数中,因此将激活结果压缩到 0 到 1 之间。以下是更新门在整个单元的位置与表示方法。更新门帮助模型决定到底要将多少过去的信息传递到未来,或到底前一时间步和当前时间步的信息有多少是需要继续传递的。这一点非常强大,因为模型能决定从过去复制所有的信息以减少梯度消失的风险。我们随后会讨论更新门的使用方法,现在只需要记住 z_t 的计算公式就行。

    2. 重置门

    本质上来说,重置门主要决定了到底有多少过去的信息需要遗忘,我们可以使用以下表达式计算:

    该表达式与更新门的表达式是一样的,只不过线性变换的参数和用处不一样而已。下图展示了该运算过程的表示方法。如前面更新门所述,h_(t-1) 和 x_t 先经过一个线性变换,再相加投入 Sigmoid 激活函数以输出激活值。

    3. 当前记忆内容

    现在我们具体讨论一下这些门控到底如何影响最终的输出。在重置门的使用中,新的记忆内容将使用重置门储存过去相关的信息,它的计算表达式为:

    输入 x_t 与上一时间步信息 h_(t-1) 先经过一个线性变换,即分别右乘矩阵 W 和 U。

    计算重置门 r_t 与 Uh_(t-1) 的 Hadamard 乘积,即 r_t 与 Uh_(t-1) 的对应元素乘积。因为前面计算的重置门是一个由 0 到 1 组成的向量,它会衡量门控开启的大小。例如某个元素对应的门控值为 0,那么它就代表这个元素的信息完全被遗忘掉。该 Hadamard 乘积将确定所要保留与遗忘的以前信息。

    将这两部分的计算结果相加再投入双曲正切激活函数中。

    4. 当前时间步的最终记忆

    在最后一步,网络需要计算 h_t,该向量将保留当前单元的信息并传递到下一个单元中。在这个过程中,我们需要使用更新门,它决定了当前记忆内容 h'_t 和前一时间步 h_(t-1) 中需要收集的信息是什么。这一过程可以表示为:

    z_t 为更新门的激活结果,它同样以门控的形式控制了信息的流入。z_t 与 h_(t-1) 的 Hadamard 乘积表示前一时间步保留到最终记忆的信息,该信息加上当前记忆保留至最终记忆的信息就等于最终门控循环单元输出的内容。

    双向RNN:Bidirectional RNN(双向RNN)假设当前t的输出不仅仅和之前的序列有关,并且 还与之后的序列有关,例如:预测一个语句中缺失的词语那么需要根据上下文进 行预测;Bidirectional RNN是一个相对简单的RNNs,由两个RNNs上下叠加在 一起组成。输出由这两个RNNs的隐藏层的状态决定。有些情况下,当前的输出不只依赖于之前的序列元素,还可能依赖之后的序列元素; 比如做完形填空,机器翻译等应用。

    # 开始网络构建
        # 1. 输入的数据格式转换
        # X格式:[batch_size, time_steps, input_size]
        X = tf.reshape(_X, shape=[-1, timestep_size, input_size])
    
        # 单层LSTM RNN
        # 2. 定义Cell
        lstm_cell_fw = tf.nn.rnn_cell.LSTMCell(num_units=hidden_size, reuse=tf.get_variable_scope().reuse)
        gru_cell_bw = tf.nn.rnn_cell.GRUCell(num_units=hidden_size, reuse=tf.get_variable_scope().reuse)
    
        # 3. 单层的RNN网络应用
        init_state_fw = lstm_cell_fw.zero_state(batch_size, dtype=tf.float32)
        init_state_bw = gru_cell_bw.zero_state(batch_size, dtype=tf.float32)
    
        # 3. 动态构建双向的RNN网络
        """
        bidirectional_dynamic_rnn(
            cell_fw: 前向的rnn cell
            , cell_bw:反向的rnn cell
            , inputs:输入的序列
            , sequence_length=None
            , initial_state_fw=None:前向rnn_cell的初始状态
            , initial_state_bw=None:反向rnn_cell的初始状态
            , dtype=None
            , parallel_iterations=None
            , swap_memory=False, time_major=False, scope=None)
        API返回值:(outputs, output_states) => outputs存储网络的输出信息,output_states存储网络的细胞状态信息
        outputs: 是一个二元组, (output_fw, output_bw)构成,output_fw对应前向的rnn_cell的执行结果,结构为:[batch_size, time_steps, output_size];output_bw对应反向的rnn_cell的执行结果,结果和output_bw一样
        output_states:是一个二元组,(output_state_fw, output_state_bw) 构成,output_state_fw和output_state_bw是dynamic_rnn API输出的状态值信息
        """
        outputs, states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=lstm_cell_fw, cell_bw=gru_cell_bw, inputs=X,
            initial_state_fw=init_state_fw, initial_state_bw=init_state_bw)
        output_fw = outputs[0][:, -1, :]
        output_bw = outputs[1][:, -1, :]
        output = tf.concat([output_fw, output_bw], 1)

    深度RNN

    Deep Bidirectional RNN(深度双向RNN)类似Bidirectional RNN,区别在于每一步的输入有多层网络,这样的话该网络便具有更加强大的表达能力和学习 能力,但是复杂性也提高了,同时需要训练更多的数据。

     

     深度RNN网络构建的代码如下:
    #多层
        def lstm_call():
            cell = tf.nn.rnn_cell.LSTMCell(num_units=hidden_size,reuse=tf.get_variable_scope().reuse)
            return tf.nn.rnn_cell.DropoutWrapper(cell,output_keep_prob=keep_prob)
        mlstm_cell = tf.nn.rnn_cell.MultiRNNCell(cells=[lstm_call() for i in range(layer_num)])
        inint_state = mlstm_cell.zero_state(batch_size,tf.float32)
        output,state = tf.nn.dynamic_rnn(mlstm_cell,inputs=X,initial_state=inint_state)
        output = output[:,-1,:]
  • 相关阅读:
    CodeForces 626 DIV.2 D Present
    PageRank 算法初步了解
    LeetCode 329. Longest Increasing Path in a Matrix(DFS,记忆化搜索)
    LeetCode 312. Burst Balloons(DP)
    LeetCode Contest 180
    用js来实现那些数据结构12(散列表)
    用js来实现那些数据结构11(字典)
    用js来实现那些数据结构10(集合02-集合的操作)
    用js来实现那些数据结构09(集合01-集合的实现)
    用js来实现那些数据结构08(链表02-双向链表)
  • 原文地址:https://www.cnblogs.com/limingqi/p/12638664.html
Copyright © 2011-2022 走看看