zoukankan      html  css  js  c++  java
  • 《PyTorch深度学习实践》第12集

    问题:在查看刘老师的《PyTorch深度学习实践》第十二集 时,发现改用embedding的方式时,维度报错,然后稍微改了点代码(不知是否正确,还望指教)

    资料:1、RNN ; 2、Embedding

    num_class = 4
    input_size = 4
    hidden_size = 8
    embedding_size = 10
    num_layers = 2
    batch_size = 1
    # seq_len = 5
    
    idx2char = ['e', 'h', 'l', 'o']
    x_data = [1, 0, 2, 2, 3]
    y_data = [3, 1, 2, 3, 2]
    
    inputs = torch.LongTensor(x_data)
    labels = torch.LongTensor(y_data)
    
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            
            self.emb = torch.nn.Embedding(input_size, embedding_size)
            # If True, then the input and output tensors are provided as (batch, seq, feature). 
            self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
            self.fc = torch.nn.Linear(hidden_size, num_class)
        
        def forward(self, x):
            hidden = torch.zeros(num_layers, batch_size, hidden_size)  # 这里也修改了
            x = self.emb(x)  # (seqlen, embedding_size)
            x = x.unsqueeze(0)  # 扩充一个维度batch:(batch, seqlen, embedding_size)
            x, _ = self.rnn(x, hidden)
            x = self.fc(x)
            return x.view(-1, num_class)
    
    
    net = Model()
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
    
    for epoch in range(15):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        _, idx = outputs.max(dim=1)
        idx = idx.data.numpy()
        print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
        print(', Epoch [%d/15] loss = %.4f' % (epoch+1, loss.item()))

    然后,输出结果如下:

    注:记录一下。警惕以后用到神经网络时,一定要记得各种dimension size的变化情况!

  • 相关阅读:
    设置全屏的方法
    The connection to adb is down,and a server error has occured.解决办法---------------------亲测有效
    android 案例二 登录界面
    javaweb项目编译错误
    Ubuntu 14.04 tomcat配置
    Ubuntu 14.03 安装jdk
    Ubuntu 14.03 安装mysql
    Git 版本管理使用说明。
    getColor问题
    WebView 调试
  • 原文地址:https://www.cnblogs.com/heyour/p/13474800.html
Copyright © 2011-2022 走看看