zoukankan      html  css  js  c++  java
  • 动手学PyTorch版RNN

    动手学PyTorch版RNN报错
    RuntimeError: expected scalar type Float but found Long
    把源代码中的RNNModel修改:
    原代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = to_onehot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(torch.stack(X), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    

    更改后的代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = F.one_hot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(X.float(), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    
  • 相关阅读:
    python 基础2.5 循环中continue与breake用法
    python 基础 2.4 while 循环
    python 基础 2.3 for 循环
    python 基础 2.2 if流程控制(二)
    python 基础 2.1 if 流程控制(一)
    python 基础 1.6 python 帮助信息及数据类型间相互转换
    python 基础 1.5 python数据类型(四)--字典常用方法示例
    Tornado Web 框架
    LinkCode 第k个排列
    LeetCode 46. Permutations
  • 原文地址:https://www.cnblogs.com/JinZL/p/14142713.html
Copyright © 2011-2022 走看看