zoukankan      html  css  js  c++  java
  • 动手学PyTorch版RNN

    动手学PyTorch版RNN报错
    RuntimeError: expected scalar type Float but found Long
    把源代码中的RNNModel修改:
    原代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = to_onehot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(torch.stack(X), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    

    更改后的代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = F.one_hot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(X.float(), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    
  • 相关阅读:
    php 数组分页
    Fchart
    thinkphp对数据库操作有哪些内置函数
    MySQL性能优化的最佳20+条经验
    apache 简单笔记
    PHPMyadmin 配置文件详解(配置)
    mysql 常用知识
    分布式微服务日志的配置
    分布式微服务的配置
    分布式接口的调用
  • 原文地址:https://www.cnblogs.com/JinZL/p/14142713.html
Copyright © 2011-2022 走看看