  • PyTorch basics: getting started with the LSTM

    My goal in learning PyTorch is to use an LSTM to process public-opinion data; the full project will be published once it is finished. The LSTM is a classic network, a kind of RNN, so I won't rehash its details here.

    A few notes on the model:

    hidden layer dimension is 100
    number of hidden layers is 1
    

    This part uses the same dataset as the previous post on logistic regression: MNIST.

    Part 1: Building the model

    # Import libraries
    import torch
    import torch.nn as nn
    
    class LSTMModel(nn.Module):
        def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
            super(LSTMModel, self).__init__()
            
            # Hidden dimensions
            self.hidden_dim = hidden_dim
    
            # Number of hidden layers
            self.layer_dim = layer_dim
    
            # LSTM
            self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True) # batch_first=True (batch_dim, seq_dim, feature_dim)
    
            # Readout layer
            self.fc = nn.Linear(hidden_dim, output_dim)
    
        def forward(self, x):
            # Initialize hidden state with zeros
            h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
    
            # Initialize cell state
            c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
    
            # 28 time steps
            # We need to detach as we are doing truncated backpropagation through time (BPTT)
            # If we don't, we'll backprop all the way to the start even after going through another batch
            out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
    
            # Index hidden state of last time step
            # out.size() --> 100, 28, 100
            # out[:, -1, :] --> 100, 100 --> just want last time step hidden states! 
            out = self.fc(out[:, -1, :]) 
            # out.size() --> 100, 10
            return out
        
    input_dim = 28
    hidden_dim = 100
    layer_dim = 1
    output_dim = 10
    model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
    
    error = nn.CrossEntropyLoss()
    
    learning_rate = 0.1
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) 
    
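To make the shape comments in the model code concrete, here is a quick sanity check (a standalone sketch, not from the original post) feeding a dummy MNIST-sized batch through an `nn.LSTM` configured the same way:

```python
import torch
import torch.nn as nn

# Same configuration as the model above: input_dim=28, hidden_dim=100, layer_dim=1
lstm = nn.LSTM(input_size=28, hidden_size=100, num_layers=1, batch_first=True)

x = torch.randn(100, 28, 28)  # (batch, seq, feature): each image row is one time step

out, (hn, cn) = lstm(x)

print(out.shape)            # torch.Size([100, 28, 100]) -- hidden state at every time step
print(hn.shape)             # torch.Size([1, 100, 100])  -- final hidden state per layer
print(out[:, -1, :].shape)  # torch.Size([100, 100])     -- what the fc readout layer consumes
```

For a single-layer, unidirectional LSTM, `out[:, -1, :]` is the same tensor as `hn[0]`, which is why indexing the last time step works as the readout input.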

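The training loop below relies on `train_loader` and `test_loader`, which this excerpt never defines (they come from the previous post's MNIST setup). A minimal stand-in is sketched here, with random tensors in place of `torchvision.datasets.MNIST` so it runs without a download; the sizes and names are assumptions, not the author's code:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 100

# In the original series these come from torchvision.datasets.MNIST; random
# tensors with the same shapes stand in here (MNIST images are 1x28x28).
train_images = torch.randn(600, 1, 28, 28)
train_labels = torch.randint(0, 10, (600,))
train_loader = DataLoader(TensorDataset(train_images, train_labels),
                          batch_size=batch_size, shuffle=True)

test_loader = DataLoader(TensorDataset(torch.randn(100, 1, 28, 28),
                                       torch.randint(0, 10, (100,))),
                         batch_size=batch_size, shuffle=False)

images, labels = next(iter(train_loader))
print(images.shape)                   # torch.Size([100, 1, 28, 28])
print(images.view(-1, 28, 28).shape)  # torch.Size([100, 28, 28]) -- (batch, seq, feature)
```

The `view(-1, seq_dim, input_dim)` reshape in the loop drops the channel dimension and treats each 28-pixel row as one of 28 time steps.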
      Part 2: Training the model

    # Number of steps to unroll
    seq_dim = 28
    num_epochs = 10  # not defined in this excerpt; ~6000 logged iterations at batch size 100 imply 10 epochs
    loss_list = []
    iteration_list = []
    accuracy_list = []
    count = 0
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            # Load images as a torch tensor with gradient accumulation abilities
            images = images.view(-1, seq_dim, input_dim).requires_grad_()
    
            # Clear gradients w.r.t. parameters
            optimizer.zero_grad()
    
            # Forward pass to get output/logits
            # outputs.size 100, 10
            outputs = model(images)
    
            # Calculate Loss: softmax --> cross entropy loss
            loss = error(outputs, labels)
    
            # Getting gradients
            loss.backward()
    
            # Updating parameters
            optimizer.step()
    
            count += 1
    
            if count % 500 == 0:
                # Calculate Accuracy         
                correct = 0
                total = 0
                for images, labels in test_loader:
                    
                    images = images.view(-1, seq_dim, input_dim)
    
                    # Forward pass only to get logits/output
                    outputs = model(images)
    
                    # Get predictions from the maximum value
                    _, predicted = torch.max(outputs.data, 1)
    
                    # Total number of labels
                    total += labels.size(0)
    
                    # Total correct predictions
                    correct += (predicted == labels).sum()
    
                accuracy = 100 * correct / total
                
                loss_list.append(loss.data.item())
                iteration_list.append(count)
                accuracy_list.append(accuracy)
                
                # Print Loss
                print('Iteration: {}. Loss: {}. Accuracy: {}'.format(count, loss.data.item(), accuracy))
    

     Results:

    Iteration: 500. Loss: 2.2601425647735596. Accuracy: 19
    Iteration: 1000. Loss: 0.9044000506401062. Accuracy: 71
    Iteration: 1500. Loss: 0.33562779426574707. Accuracy: 88
    Iteration: 2000. Loss: 0.29831066727638245. Accuracy: 92
    Iteration: 2500. Loss: 0.20772598683834076. Accuracy: 94
    Iteration: 3000. Loss: 0.13703776895999908. Accuracy: 95
    Iteration: 3500. Loss: 0.1824885755777359. Accuracy: 95
    Iteration: 4000. Loss: 0.021043945103883743. Accuracy: 96
    Iteration: 4500. Loss: 0.13939177989959717. Accuracy: 96
    Iteration: 5000. Loss: 0.032742198556661606. Accuracy: 96
    Iteration: 5500. Loss: 0.1308797001838684. Accuracy: 96
    
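A note on the logged accuracies: `correct` is a `LongTensor`, and on the older PyTorch version this log apparently comes from, `100 * correct / total` truncates to an integer, which is why the accuracies print as whole numbers like 96. A small fix (not in the original post) converts the tensor to a Python number first:

```python
import torch

correct = torch.tensor(9632)  # hypothetical count of correct predictions
total = 10000

accuracy = 100 * correct.item() / total  # plain float division
print(accuracy)  # 96.32
```

Recent PyTorch versions return a float tensor from integer-tensor division anyway, but calling `.item()` keeps the stored accuracy a plain Python float regardless of version.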

     Part 3: Visualization

    # visualization: loss
    import matplotlib.pyplot as plt

    plt.plot(iteration_list, loss_list)
    plt.xlabel("Number of iteration")
    plt.ylabel("Loss")
    plt.title("LSTM: Loss vs Number of iteration")
    plt.show()
    
    # visualization: accuracy
    plt.plot(iteration_list, accuracy_list, color="red")
    plt.xlabel("Number of iteration")
    plt.ylabel("Accuracy")
    plt.title("LSTM: Accuracy vs Number of iteration")
    plt.savefig('graph.png')
    plt.show()
    

     Results: (the loss and accuracy plots are not included in this excerpt)

  • Original post: https://www.cnblogs.com/zhuozige/p/14696784.html