zoukankan      html  css  js  c++  java
  • Pytorch-LSTM

    1.nn.LSTM

    1.1lstm=nn.LSTM(input_size, hidden_size, num_layers)

    参数:

    • input_size:输入特征的维度, 一般rnn中输入的是词向量,那么 input_size 就等于一个词向量的维度,即feature_len;
    • hidden_size:隐藏层神经元个数,或者也叫输出的维度(因为rnn输出为各个时间步上的隐藏状态);
    • num_layers:网络的层数;

    1.2out, (ht, ct) = lstm(x, [ht0, ct0])

    • x:[seq_len, batch, feature_len]
    • h/c:[num_layer, batch, hidden_len]
    • out:[seq_len, batch, hidden_len]
     1 import torch
     2 from torch import nn
     3 
     4 lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)   #4层的LSTM,输入的每个词用100维向量表示,隐藏单元和记忆单元的尺寸是20
     5 
     6 x = torch.randn(10, 3, 100)    #3句话,每句10个单词,每个单词表示为长100的向量
     7 out, (h, c) = lstm(x)          #不传入h_0和c_0则会默认初始化
     8 print(out.shape)               #torch.Size([10, 3, 20]) 
     9 print(h.shape)                 #torch.Size([4, 3, 20]) 
    10 print(c.shape)                 #torch.Size([4, 3, 20]) 

    2.nn.LSTMCell

    nn.LSTMCell与nn.LSTM的区别和nn.RNN与nn.RNNCell的区别一样。

    2.1nn.LSTMCell()

    初始化方法和上面一样。

    2.2ht, ct = lstmcell(xt, [ht-1, ct-1])

    • xt:[batch, feature_len]表示t时刻的输入
    • ht-1, ct-1:[batch, hidden_len] t-1时刻本层的隐藏单元和记忆单元
     1 #单层LSTM
     2 import torch
     3 from torch import nn
     4 
     5 cell = nn.LSTMCell(input_size=100, hidden_size=20)       #1层的LSTM,输入的每个词用100维向量表示,隐藏单元和记忆单元的尺寸是20
     6 
     7 h = torch.zeros(3, 20)                                   #初始化隐藏单元h和记忆单元c,取batch=3
     8 c = torch.zeros(3, 20)
     9 
    10 x = [torch.randn(3, 100) for _ in range(10)]             #seq_len=10个时刻的输入,每个时刻shape都是[batch,feature_len]
    11 
    12 for xt in x:                                             #对每个时刻,传入输入xt和上个时刻的h和c
    13     h, c = cell(xt, (h, c))
    14 
    15 print(h.shape,c.shape)                                   #torch.Size([3, 20]) torch.Size([3, 20])
    16 
    17 
    18 #两层LSTM
    19 cell_l0 = nn.LSTMCell(input_size=100, hidden_size=30)    #输入的feature_len=100,变到该层隐藏单元和记忆单元hidden_len=30
    20 cell_l1 = nn.LSTMCell(input_size=30, hidden_size=20)     #hidden_len从l0层的30变到这一层的20
    21 
    22 h_l0 = torch.zeros(3, 30)           #分别初始化l0层和l1层的隐藏单元h和记忆单元C,取batch=3
    23 C_l0 = torch.zeros(3, 30)
    24 
    25 h_l1 = torch.zeros(3, 20)
    26 C_l1 = torch.zeros(3, 20)
    27 
    28 x = [torch.randn(3, 100) for _ in range(10)]             #seq_len=10个时刻的输入,每个时刻shape都是[batch,feature_len]
    29 
    30 for xt in x:
    31     h_l0, C_l0 = cell_l0(xt, (h_l0, C_l0))               #l0层接受xt输入
    32     h_l1, C_l1 = cell_l1(h_l0, (h_l1, C_l1))             #l1层接受l0层的输出h作为输入
    33 
    34 print(h_l0.shape, C_l0.shape)                            #torch.Size([3, 30]) torch.Size([3, 30]) 
    35 print(h_l1.shape, C_l1.shape)                            #torch.Size([3, 20]) torch.Size([3, 20])
  • 相关阅读:
    centos7系统初始化
    瀑布流无限加载infinitescroll插件与masonry插件使用
    网页前端导出CSV,Excel格式文件
    js添加收藏夹
    Fiddler修改http请求响应简单实例
    Nodejs的Gruntjs使用一则
    Jquery插件validate使用一则
    PostgreSQL操作-psql基本命令
    SSH连接时出现Host key verification failed的原因及解决方法以及ssh-keygen命令的用法
    在ubuntu20.04上设置python2为默认方式
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13355203.html
Copyright © 2011-2022 走看看