zoukankan      html  css  js  c++  java
  • 关于torch.nn.LSTM()的输入和输出

    主角torch.nn.LSTM()

    初始化时要传入的参数

     |  Args:
     |      input_size: The number of expected features in the input `x`
     |      hidden_size: The number of features in the hidden state `h`
     |      num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
     |          would mean stacking two LSTMs together to form a `stacked LSTM`,
     |          with the second LSTM taking in outputs of the first LSTM and
     |          computing the final results. Default: 1
     |      bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
     |          Default: ``True``
     |      batch_first: If ``True``, then the input and output tensors are provided
     |          as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
     |          Note that this does not apply to hidden or cell states. See the
     |          Inputs/Outputs sections below for details.  Default: ``False``
     |      dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
     |          LSTM layer except the last layer, with dropout probability equal to
     |          :attr:`dropout`. Default: 0
     |      bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
     |      proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
    

    input_size:一般是词嵌入的大小
    hidden_size:隐含层的维度
    num_layers:默认是1,单层LSTM
    bias:是否使用bias
    batch_first:默认为False,如果设置为True,则表示第一个维度表示的是batch_size
    dropout:直接看英文吧
    bidirectional:默认为False,表示单向LSTM,当设置为True,表示为双向LSTM,一般和num_layers配合使用(需要注意的是当该项设置为True时,将num_layers设置为1,表示由1个双向LSTM构成)

    模型输入输出-单向LSTM

    import torch
    import torch.nn as nn
    import numpy as np
    
    inputs_numpy = np.random.random((64,32,300))
    inputs = torch.from_numpy(inputs_numpy).to(torch.float32)
    inputs.shape
    

    torch.Size([64, 32, 300]):表示[batchsize, max_length, embedding_size]

    hidden_size = 128
    lstm = nn.LSTM(300, 128, batch_first=True, num_layers=1)
    output, (hn, cn) = lstm(inputs)
    print(output.shape)
    print(hn.shape)
    print(cn.shape)
    

    torch.Size([64, 32, 128])
    torch.Size([1, 64, 128])
    torch.Size([1, 64, 128])
    说明:
    output:保存了每个时间步的输出,如果想要获取最后一个时间步的输出,则可以这么获取:output_last = output[:,-1,:]
    h_n:包含的是句子的最后一个单词的隐藏状态,与句子的长度seq_length无关
    c_n:包含的是句子的最后一个单词的细胞状态,与句子的长度seq_length无关
    另外:最后一个时间步的输出等于最后一个隐含层的输出

    output_last = output[:,-1,:]
    hn_last = hn[-1]
    print(output_last.eq(hn_last))
    

    image

    模型输入输出-双向LSTM

    首先我们要明确:
    output :(seq_len, batch, num_directions * hidden_size)
    h_n:(num_layers * num_directions, batch, hidden_size)
    c_n :(num_layers * num_directions, batch, hidden_size)
    其中num_layers表示层数,这里是1,num_directions表示方向数,由于是双向的,这里是2,也是,我们就有下面的结果:

    import torch
    import torch.nn as nn
    import numpy as np
    
    inputs_numpy = np.random.random((64,32,300))
    inputs = torch.from_numpy(inputs_numpy).to(torch.float32)
    inputs.shape
    hidden_size = 128
    lstm = nn.LSTM(300, 128, batch_first=True, num_layers=1, bidirectional=True)
    output, (hn, cn) = lstm(inputs)
    print(output.shape)
    print(hn.shape)
    print(cn.shape)
    

    torch.Size([64, 32, 256])
    torch.Size([2, 64, 128])
    torch.Size([2, 64, 128])
    这里面的hn包含两个元素,一个是正向的隐含层输出,一个是方向的隐含层输出。

    #获取反向的最后一个output
    output_last_backward = output[:,0,-hidden_size:]
    #获反向最后一层的hn
    hn_last_backward = hn[-1]
     
    #反向最后的output等于最后一层的hn
    print(output_last_backward.eq(hn_last_backward))
     
    #获取正向的最后一个output
    output_last_forward = output[:,-1,:hidden_size]
    #获取正向最后一层的hn
    hn_last_forward = hn[-2]
    # 反向最后的output等于最后一层的hn
    print(output_last_forward.eq(hn_last_forward))
    

    image

    https://www.cnblogs.com/LiuXinyu12378/p/12322993.html
    https://blog.csdn.net/m0_45478865/article/details/104455978
    https://blog.csdn.net/foneone/article/details/104002372

  • 相关阅读:
    C#获取Word文档页数,并跳转到指定的页面获取页面信息
    GC 垃圾回收
    Open Flash Chart 之线图
    Open Flash Chart 之线图(二)
    Nullable可空类型
    System.AppDomain类
    C# 事件
    C#方法的参数 Ref Out Params 4种类型的参数
    单向链表
    C# 结构体 struct
  • 原文地址:https://www.cnblogs.com/xiximayou/p/15036715.html
Copyright © 2011-2022 走看看