zoukankan      html  css  js  c++  java
  • 《PyTorch深度学习实践》第12集

    问题:在查看刘老师的《PyTorch深度学习实践》第十二集 时,发现改用embedding的方式时,维度报错,然后稍微改了点代码(不知是否正确,还望指教)

    资料:1、RNN ; 2、Embedding

    num_class = 4
    input_size = 4
    hidden_size = 8
    embedding_size = 10
    num_layers = 2
    batch_size = 1
    # seq_len = 5
    
    idx2char = ['e', 'h', 'l', 'o']
    x_data = [1, 0, 2, 2, 3]
    y_data = [3, 1, 2, 3, 2]
    
    inputs = torch.LongTensor(x_data)
    labels = torch.LongTensor(y_data)
    
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            
            self.emb = torch.nn.Embedding(input_size, embedding_size)
            # If True, then the input and output tensors are provided as (batch, seq, feature). 
            self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
            self.fc = torch.nn.Linear(hidden_size, num_class)
        
        def forward(self, x):
            hidden = torch.zeros(num_layers, batch_size, hidden_size)  # 这里也修改了
            x = self.emb(x)  # (seqlen, embedding_size)
            x = x.unsqueeze(0)  # 扩充一个维度batch:(batch, seqlen, embedding_size)
            x, _ = self.rnn(x, hidden)
            x = self.fc(x)
            return x.view(-1, num_class)
    
    
    net = Model()
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
    
    for epoch in range(15):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        _, idx = outputs.max(dim=1)
        idx = idx.data.numpy()
        print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
        print(', Epoch [%d/15] loss = %.4f' % (epoch+1, loss.item()))

    然后,输出结果如下:

    注:记录一下。警惕以后用到神经网络时,一定要记得各种dimension size的变化情况!

  • 相关阅读:
    关于Spring的destroy-method和scope="prototype"不能共存问题
    关于引入文件名字问题
    技术学习路
    web.xml文件配置
    性能测试中的TPS与HPS
    设计模式简介
    Cause of 400 Bad Request Errors
    vim使用技巧
    如何更好地利用Pmd、Findbugs和CheckStyle分析结果
    Java基础知识
  • 原文地址:https://www.cnblogs.com/heyour/p/13474800.html
Copyright © 2011-2022 走看看