zoukankan      html  css  js  c++  java
  • RNN、lstm、gru详解

    一、RNN

    RNN结构:

    RNN的结构是由一个输入层、隐藏层、输出层组成:

     将RNN的结构按照时间序列展开

    其中$U_{t-1}、U_{t}、U_{t+1}$三者是同一个值,只是按着时刻称呼不一样而已,对应的W和V也是一样。

    对应的前向传播公式和对应的每个时刻的输出公式

    $S_{t-1}=U_{t-1}X_{t-1}+W_{t-1}S_{t-2}+b_1 qquad y_{t-1}=V_{t-1}S_{t-1}+b_2$

    $S_{t}=U_{t}X_{t}+W_{t}S_{t-1}+b_1 qquad qquad y_{t}=V_{t}S_{t}+b_2$

    $S_{t+1}=U_{t+1}X_{t+1}+W_{t+1}S_{t}+b_1 qquad y_{t+1}=V_{t+1}S_{t+1}+b_2$ 

    二、LSTM(Long Short-Term Memory,长短期记忆网络) 

    LSTM是一种特殊的RNN类型,一般的RNN结构如下图所示,是一种将以往学习的结果应用到当前学习的模型,但是这种一般的RNN存在着许多的弊端。举个例子,如果我们要预测“the clouds are in the sky”的最后一个单词,因为只在这一个句子的语境中进行预测,那么将很容易地预测出是这个单词是sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,RNN 可以学会使用先前的信息。 

     

    标准的RNN结构中只有一个神经元,一个tanh层进行重复的学习,这样会存在一些弊端。例如,在比较长的环境中,例如在“I grew up in France… I speak fluent French”中去预测最后的French,那么模型会推荐一种语言的名字,但是预测具体是哪一种语言时就需要用到很远以前的Franch,这就说明在长环境中相关的信息和预测的词之间的间隔可以是非常长的。在理论上,RNN 绝对可以处理这样的长环境问题。人们可以仔细挑选参数来解决这类问题中的最初级形式,但在实践中,RNN 并不能够成功学习到这些知识。然而,LSTM模型就可以解决这一问题

    如图所示,标准LSTM模型是一种特殊的RNN类型,在每一个重复的模块中有四个特殊的结构,以一种特殊的方式进行交互。在图中,每一条黑线传输着一整个向量,粉色的圈代表一种pointwise 操作(将定义域上的每一点的函数值分别进行运算),诸如向量的和,而黄色的矩阵就是学习到的神经网络层。 
    LSTM模型的核心思想是“细胞状态”。“细胞状态”类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。 

    LSTM 有通过精心设计的称作为“门”的结构来去除或者增加信息到细胞状态的能力。门是一种让信息选择式通过的方法。他们包含一个 sigmoid 神经网络层和一个 pointwise 乘法操作。 
    LSTM门 
    Sigmoid 层输出 0 到 1 之间的数值,描述每个部分有多少量可以通过。0 代表“不许任何量通过”,1 就指“允许任意量通过”。LSTM 拥有三个门,来保护和控制细胞状态。 

    在LSTM模型中,第一步是决定我们从“细胞”中丢弃什么信息,这个操作由一个忘记门层来完成。该层读取当前输入x和前神经元信息h,由ft来决定丢弃的信息。输出结果1表示“完全保留”,0 表示“完全舍弃”。 

    第二步是确定细胞状态所存放的新信息,这一步由两层组成。sigmoid层作为“输入门层”,决定我们将要更新的值i;tanh层来创建一个新的候选值向量~Ct加入到状态中。在语言模型的例子中,我们希望增加新的主语到细胞状态中,来替代旧的需要忘记的主语。 
     

    第三步就是更新旧细胞的状态,将Ct-1更新为Ct。我们把旧状态与 ft相乘,丢弃掉我们确定需要丢弃的信息。接着加上 it * ~Ct。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。在语言模型的例子中,这就是我们实际根据前面确定的目标,丢弃旧代词的信息并添加新的信息的地方。 


    最后一步就是确定输出了,这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。首先,我们运行一个 sigmoid 层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。在语言模型的例子中,因为语境中有一个代词,可能需要输出与之相关的信息。例如,输出判断是一个动词,那么我们需要根据代词是单数还是负数,进行动词的词形变化。

    三、GRU(Gated Recurrent Unit, LSTM变体)


    GRU作为LSTM的一种变体,将忘记门和输入门合成了一个单一的更新门。同样还混合了细胞状态和隐藏状态,加诸其他一些改动。最终的模型比标准的 LSTM 模型要简单,也是非常流行的变体。

    四、对比

    LSTM与GRU对比 
    概括的来说,LSTM和GRU都能通过各种Gate将重要特征保留,保证其在long-term 传播的时候也不会被丢失。 
    结果对比1 
    可以看出,标准LSTM和GRU的差别并不大,但是都比tanh要明显好很多,所以在选择标准LSTM或者GRU的时候还要看具体的任务是什么。 
    使用LSTM的原因之一是解决RNN Deep Network的Gradient错误累积太多,以至于Gradient归零或者成为无穷大,所以无法继续进行优化的问题。GRU的构造更简单:比LSTM少一个gate,这样就少几个矩阵乘法。在训练数据很大的情况下GRU能节省很多时间。

    五、LSTM具体程序示例

     经常编码方式有两种:

    1.lstm

    #coding=utf-8
    import torch
    import torch.nn as nn
    
    '''
    假如输入有3个句子,每个句子都由5个单词组成,每个单词用10维的词向量表示,
    则batch=3, seq_len=5, Embedding=10.
    '''
    
    #设置LSTM的参数;词向量维数10,隐藏元维度20,2个LSTM层串联,双向lstm
    bilstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
    
    
    #如下表示输入的句子
    input = torch.randn(5, 3, 10)#[seq_len,batch_size,Embedding]
    #初始化的隐藏元和记忆元,通常它们是维度是一样的
    h0 = torch.randn(4, 3, 20)#[bidirection*num_layers,batch_size,hidden_size]
    c0 = torch.randn(4, 3, 20)#[bidirection*num_layers,batch_size,hidden_size]
    
    #这里有2层lstm,output是最后一层lstm的每个词向量对应隐藏层的输出,与层数无关,只与序列长度相关
    #hn,cn是所有层最后一个隐藏元和记忆元的输出
    output, (hn, cn) = bilstm(input, (h0, c0))
    print('output shape: ', output.shape)#output shape:  torch.Size([5, 3, 40]),[seq_len,batch_size,2*hidden_size]
    print('hn shape: ', hn.shape)#hn shape:  torch.Size([4, 3, 20]),[bidirection*num_layers,batch_size,hidden_size]
    print('cn shape: ', cn.shape)#cn shape:  torch.Size([4, 3, 20]),[bidirection*num_layers,batch_size,hidden_size]
    
    #将输出的数据做一个二分类
    output=output.permute(1,0,2)#torch.Size([3, 5, 40]),[batch_size,seq_len,2*hidden_size]
    output=output.contiguous()
    batch_size=output.size(0)
    output=output.view(batch_size,-1)#torch.Size([3, 200]),[batch_size,2*seq_len*hidden_size]
    fully_connected=nn.Linear(200,2)
    output=fully_connected(output)
    print(output.shape)#torch.Size([3, 2]),[batch_size,class]

    2.LSTMCell

    需要自己定义循环的次数

    import torch
    import torch.nn as nn
    '''
    函数参数torch.nn.LSTMCell(input_size, hidden_size, bias=True)
    '''
    #其中input_size是5,hidden_size是10
    lstm_cell = nn.LSTMCell(5,10)
    
    #初始化参数
    input = torch.randn(2,5)#[batch_size,input_size]
    h = torch.randn(2,10)#[batch_size,hidden_size]
    c = torch.randn(2,10)#[batch_size,hidden_size]
    #自定义循环次数
    T=10
    for time_step in range(T):
        h,c = lstm_cell(input,(h,c))
    print(h.shape,c.shape)#torch.Size([2, 10]) torch.Size([2, 10])
    
    
    #一个具体的应用例子,将lstm的输出做一个二分类
    #对输出的数据,做一个fully_connected
    fully_connected=nn.Linear(10,2)
    output=fully_connected(h)#[batch_size,class]
    print(output.size())#torch.Size([2, 2])

    参考链接

    https://zhuanlan.zhihu.com/p/30844905

    https://blog.csdn.net/lreaderl/article/details/78022724

  • 相关阅读:
    ppt 制作圆角三角形
    ROS 错误之 [rospack] Error: package 'beginner_tutorials' not found
    ubuntu下安装搜狗输入法以及出现不能输入中文的解决办法
    <crtdbg.h> 的作用
    mybatis配置
    POJO、Bean和JavaBean
    类类型与反射
    Spring层面的事务管理
    java项目常用架构
    java 遍历数组的几种方式
  • 原文地址:https://www.cnblogs.com/AntonioSu/p/8798960.html
Copyright © 2011-2022 走看看