  • pytorch-LSTM(): a brief introduction

    Reference: https://www.cnblogs.com/duye/p/9913386.html

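    Parameters of torch.nn.LSTM():

    • input_size: number of features in the input
    • hidden_size: number of features in the hidden state
    • num_layers: number of stacked LSTM layers (default 1)
    • bias: whether the layers use bias weights; default True
    • batch_first: whether the input and output put batch_size in the first dimension, i.e. (batch, seq, feature). By default PyTorch puts batch_size in the second dimension, (seq, batch, feature), so this is usually set to True
    • dropout: dropout probability applied to the output of every LSTM layer except the last (default 0, i.e. no dropout); it only has an effect when num_layers > 1
    • bidirectional: whether the RNN is bidirectional; default False. If True, num_directions = 2, otherwise 1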
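    To see how these parameters shape the module, here is a minimal sketch (the sizes are arbitrary, chosen only for illustration) that builds a 2-layer bidirectional LSTM and lists its learnable weights. PyTorch stacks the four gate matrices, so every weight has 4 * hidden_size rows; the `_reverse` entries exist only because bidirectional=True, and the bias entries only because bias=True.

```python
import torch.nn as nn

# Arbitrary illustrative sizes
input_size, hidden_size, num_layers = 8, 16, 2

lstm = nn.LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    bias=True,           # adds bias_ih_l* and bias_hh_l* parameters
    batch_first=True,    # input/output laid out as (batch, seq, feature)
    dropout=0.2,         # applied after layer 0 only, never after the last layer
    bidirectional=True,  # adds *_reverse parameters; num_directions = 2
)

params = dict(lstm.named_parameters())
for name, param in params.items():
    print(name, tuple(param.shape))
```

    Layer 0 sees input_size features, while layer 1 sees the concatenated forward and backward outputs of layer 0 (2 * hidden_size features), which is why weight_ih_l1 is wider than weight_ih_l0. With num_layers=1, a nonzero dropout does nothing and PyTorch emits a warning.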

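    LSTM outputs:

    • out, hidden = lstm(input, hidden)
    • shape of out (with batch_first=True): (batch_size, seq_len, num_directions * hidden_size)
    • hidden = (hn, cn)
    • shape of hn and cn: (num_layers * num_directions, batch_size, hidden_size); batch_first does not affect these two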
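    A quick shape check makes these rules concrete. The sketch below (sizes chosen only for illustration) runs a 2-layer bidirectional LSTM on random data and also shows how hn relates to out: the last layer's forward state is the forward half of out at the final time step, while its backward state is the backward half of out at time step 0, because the backward direction reads the sequence from the end.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len = 3, 5
input_size, hidden_size, num_layers = 8, 16, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
               batch_first=True, bidirectional=True)
x = torch.randn(batch_size, seq_len, input_size)  # (batch, seq, feature) since batch_first=True

out, (hn, cn) = lstm(x)  # no hidden state passed: h0 and c0 default to zeros

print(out.shape)  # torch.Size([3, 5, 32]): last dim is num_directions * hidden_size
print(hn.shape)   # torch.Size([4, 3, 16]): (num_layers * num_directions, batch, hidden_size)
print(cn.shape)   # torch.Size([4, 3, 16])

# hn[-2]: last layer, forward direction, state after the final time step
# hn[-1]: last layer, backward direction, state after reaching time step 0
print(torch.allclose(out[:, -1, :hidden_size], hn[-2]))  # True
print(torch.allclose(out[:, 0, hidden_size:], hn[-1]))   # True
```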

    A one-layer bidirectional LSTM model is typically defined as follows (for an NER task, where the output classifies each character):

    import torch
    import torch.nn as nn
    
    
    class MyBiLSTM(nn.Module):
        def __init__(self, vocab_size, embedding_dim, hidden_dim, tag_size):
            super(MyBiLSTM, self).__init__()
            self.vocab_size = vocab_size
            self.embedding_dim = embedding_dim
            self.hidden_dim = hidden_dim
            self.tag_size = tag_size
            self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
            self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
            # *2 because the LSTM is bidirectional: forward and backward outputs are concatenated
            self.output_layer = nn.Linear(hidden_dim * 2, tag_size)
    
        def init_hidden(self, batch_size):
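            # Initial (h0, c0) sampled from a standard normal distribution (PyTorch uses zeros if none is passed).
            # First dim = num_layers * num_directions = 1 * 2 = 2, since the LSTM is bidirectional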
            return (torch.randn(2, batch_size, self.hidden_dim), torch.randn(2, batch_size, self.hidden_dim))
    
        def forward(self, input):
            self.hidden = self.init_hidden(len(input))
            embeds = self.embedding_layer(input)
            lstm_out, self.hidden = self.lstm_layer(embeds, self.hidden)
            model_out = self.output_layer(lstm_out)
            return model_out
    
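    model = MyBiLSTM(5, 10, 10, 4)  # vocab_size=5, embedding_dim=10, hidden_dim=10, tag_size=4
    input = torch.tensor([[0, 1, 2, 3, 4]])  # one sentence (batch_size=1) of 5 character indices
    print(model(input))  # tensor of shape (1, 5, 4): 4 tag scores for each of the 5 characters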
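    For the NER use case, the per-character scores still have to be turned into tag predictions and a training loss. The sketch below shows one common choice, plain per-character cross-entropy; it rebuilds the same layer stack as MyBiLSTM inline so it runs on its own, and the gold tag ids are hypothetical values made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embedding_dim, hidden_dim, tag_size = 5, 10, 10, 4

# Same layers as MyBiLSTM, written inline so this snippet is self-contained
embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
output_layer = nn.Linear(hidden_dim * 2, tag_size)

tokens = torch.tensor([[0, 1, 2, 3, 4]])     # (batch=1, seq_len=5)
gold_tags = torch.tensor([[0, 1, 1, 2, 3]])  # hypothetical gold tag ids, one per character

lstm_out, _ = lstm(embedding(tokens))  # (1, 5, 2 * hidden_dim); h0/c0 default to zeros
scores = output_layer(lstm_out)        # (1, 5, tag_size)

# CrossEntropyLoss expects (N, C) scores and (N,) targets, so flatten batch and time
loss = nn.CrossEntropyLoss()(scores.reshape(-1, tag_size), gold_tags.reshape(-1))

pred_tags = scores.argmax(dim=-1)      # (1, 5): highest-scoring tag per character
print(loss.item(), pred_tags)
```

    Flattening batch and time treats every character as an independent classification example, which matches the "classify each character" framing above.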
  • Original post: https://www.cnblogs.com/mingriyingying/p/13381578.html