zoukankan      html  css  js  c++  java
  • 动手学PyTorch版RNN

    动手学PyTorch版RNN报错
    RuntimeError: expected scalar type Float but found Long
    把源代码中的RNNModel修改:
    原代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = to_onehot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(torch.stack(X), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    

    更改后的代码

    class RNNModel(nn.Module):
        def __init__(self, rnn_layer, vocab_size):
            super(RNNModel, self).__init__()
            self.rnn = rnn_layer
            self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
            self.vocab_size = vocab_size
            self.dense = nn.Linear(self.hidden_size, vocab_size)
            self.state = None
    
        def forward(self, inputs, state): # inputs: (batch, seq_len)
            # 获取one-hot向量表示
            X = F.one_hot(inputs, self.vocab_size) # X是个list
            Y, self.state = self.rnn(X.float(), state)
            # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
            # 形状为(num_steps * batch_size, vocab_size)
            output = self.dense(Y.view(-1, Y.shape[-1]))
            return output, self.state
    
  • 相关阅读:
    sqli-labs lexx25-28a(各种过滤)
    sqli-labs less-24(二次注入)
    sqli-labs less13-20(各种post型头部注入)
    sql注入之双查询注入
    sqli-labs less11-12(post型union注入)
    sqli-labs less8-10(布尔盲注时间盲注)
    sqli-labs less-7(文件读写)
    Vue ref childNode 坑
    Blob
    中文输入法不触发onkeyup事件的解决办法Script
  • 原文地址:https://www.cnblogs.com/JinZL/p/14142713.html
Copyright © 2011-2022 走看看