zoukankan      html  css  js  c++  java
  • PyTorch-LSTM

     1 import torch
     2 import torch.nn as nn
     3 
     4 torch.random.manual_seed(10)
     5 
     6 input_size = 2  # 输入向量维度
     7 hidden_size = 4 # 隐层层维度
     8 num_layers = 2 # 层数
     9 
    10 lstm = nn.LSTM(input_size, hidden_size, num_layers)
    11 
    12 
    13 # Input:
    14 
    15 # input of shape (sep_len, bath, input_size)
    16 # h_t-1 of shape (num_directions * num_layers, bath, hidden_size)
    17 # c_t-1 for shape (num_directions * num_layers, bath, hidden_size)
    18 
    19 # Output:
    20 # output of shape (sep_len, bath, num_directions * hidden_size)
    21 # h_t-1 of shape (num_directions * num_layers, bath, hidden_size)
    22 # c_t-1 for shape (num_directions * num_layers, bath, hidden_size)
    23 
    24 # two ways
    25 Input = torch.randn(4, 3, 2)
    26 h = torch.randn(2, 3, 4)
    27 c = torch.randn(2, 3, 4)
    28 output = None
    29 
    30 # first
    31 h1 = h
    32 c1 = c
    33 for it in Input:
    34     output, (h1, c1) = lstm(it.view(1, 3, -1), (h1, c1))
    35     print((output == h1[-1]).all().item())
    36 print(output)
    37 
    38 # second
    39 output1, (h, c) = lstm(Input,(h, c))
    40 print(output1[-1])
    41 # print(output1[-1] == output) 精度的问题
  • 相关阅读:
    随笔1
    随笔
    shared_ptr<> reset
    c++模板库(简介)
    rockmongo用法
    随笔
    TEXT宏,TCHAR类型
    sprintf
    基于SOA的银行系统架构
    大纲6 信息化规划与管理
  • 原文地址:https://www.cnblogs.com/xidian-mao/p/12112858.html
Copyright © 2011-2022 走看看