zoukankan      html  css  js  c++  java
  • 动手学PyTorch版RNN

    动手学PyTorch版RNN报错
    RuntimeError: expected scalar type Float but found Long
    把源代码中的RNNModel修改:
    原代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = to_onehot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(torch.stack(X), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    

    更改后的代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = F.one_hot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(X.float(), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    
  • 相关阅读:
    TP实例化模型的两种方式 M() D()
    implode 函数 把数组拼接成字符串
    用array_search 数组中查找是否存在这个 值
    SVN-001
    PHP-006
    Access数据操作-02
    Access数据操作-01
    Html解析
    浏览器Chrome对WebGL支持判断
    浏览器渲染模式设置
  • 原文地址:https://www.cnblogs.com/JinZL/p/14142713.html
Copyright © 2011-2022 走看看