zoukankan      html  css  js  c++  java
  • 动手学PyTorch版RNN

    动手学PyTorch版RNN报错
    RuntimeError: expected scalar type Float but found Long
    把源代码中的RNNModel修改:
    原代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = to_onehot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(torch.stack(X), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    

    更改后的代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = F.one_hot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(X.float(), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    
  • 相关阅读:
    keyCode的使用
    写自已的类库需要的核心代码
    50个必备的实用jQuery代码段
    javascript基础
    给js原生Array增加each方法
    jquery中一些容易让人困惑的东西总结[转载]
    ajax编程
    oracle的正则表达式 [转载]
    eclipse 插件大全
    SlickGrid Options
  • 原文地址:https://www.cnblogs.com/JinZL/p/14142713.html
Copyright © 2011-2022 走看看