zoukankan      html  css  js  c++  java
  • torch.nn.LSTM()函数维度详解

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    lstm=nn.LSTM(input_size,                     hidden_size,                      num_layers)
    x                         seq_len,                          batch,                              input_size
    h0            num_layers× imes×num_directions,   batch,                             hidden_size
    c0            num_layers× imes×num_directions,   batch,                             hidden_size

    output                 seq_len,                         batch,                num_directions× imes×hidden_size
    hn            num_layers× imes×num_directions,   batch,                             hidden_size
    cn            num_layers× imes×num_directions,    batch,                            hidden_size

    举个例子:
    对句子进行LSTM操作

    假设有100个句子(sequence),每个句子里有7个词,batch_size=64,embedding_size=300

    此时,各个参数为:
    input_size=embedding_size=300
    batch=batch_size=64
    seq_len=7

    另外设置hidden_size=100, num_layers=1

    import torch
    import torch.nn as nn
    lstm = nn.LSTM(300, 100, 1)
    x = torch.randn(7, 64, 300)
    h0 = torch.randn(1, 64, 100)
    c0 = torch.randn(1, 64, 100)
    output, (hn, cn)=lstm(x, (h0, c0))

    >>
    output.shape torch.Size([7, 64, 100])
    hn.shape torch.Size([1, 64, 100])
    cn.shape torch.Size([1, 64, 100])
    ---------------------
    作者:huxuedan01
    来源:CSDN
    原文:https://blog.csdn.net/m0_37586991/article/details/88561746
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    P1828 [USACO3.2]香甜的黄油 Sweet Butter 题解
    P2058 海港 题解
    浅谈三分算法
    海伦公式的证明
    一年一回首
    再谈单调队列优化 & 背包九讲
    浅谈单调队列
    P1440 求m区间内的最小值 题解
    CF1374B Multiply by 2, divide by 6 题解
    组合数、杨辉三角与递推算法
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11187387.html
Copyright © 2011-2022 走看看