zoukankan      html  css  js  c++  java
  • torch.nn.LSTM()函数维度详解

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    lstm=nn.LSTM(input_size,                     hidden_size,                      num_layers)
    x                         seq_len,                          batch,                              input_size
    h0            num_layers× imes×num_directions,   batch,                             hidden_size
    c0            num_layers× imes×num_directions,   batch,                             hidden_size

    output                 seq_len,                         batch,                num_directions× imes×hidden_size
    hn            num_layers× imes×num_directions,   batch,                             hidden_size
    cn            num_layers× imes×num_directions,    batch,                            hidden_size

    举个例子:
    对句子进行LSTM操作

    假设有100个句子(sequence),每个句子里有7个词,batch_size=64,embedding_size=300

    此时,各个参数为:
    input_size=embedding_size=300
    batch=batch_size=64
    seq_len=7

    另外设置hidden_size=100, num_layers=1

    import torch
    import torch.nn as nn
    lstm = nn.LSTM(300, 100, 1)
    x = torch.randn(7, 64, 300)
    h0 = torch.randn(1, 64, 100)
    c0 = torch.randn(1, 64, 100)
    output, (hn, cn)=lstm(x, (h0, c0))

    >>
    output.shape torch.Size([7, 64, 100])
    hn.shape torch.Size([1, 64, 100])
    cn.shape torch.Size([1, 64, 100])
    ---------------------
    作者:huxuedan01
    来源:CSDN
    原文:https://blog.csdn.net/m0_37586991/article/details/88561746
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    Spring 事务管理tx,aop
    好的博客参考之Spring
    Spring 事务管理
    Eclipse+Tomcat+MAVEN+SVN项目完整环境搭建
    ssm框架搭建
    SSH整合不错的博客
    org.springframework.beans.factory.CannotLoadBeanClassException: Cannot find class [com.my.service.ProductService] for bean with name 'productService' defi报错解决方法
    修改firefox的默认缩放比
    深入理解计算机系统笔记
    OnePlus5刷机后一直检查更新
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11187387.html
Copyright © 2011-2022 走看看