  • PyTorch Basics: ANN (Artificial Neural Network)

    This part introduces an ANN in PyTorch. The ANN here has three hidden layers.
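As a quick size check, here is a minimal sketch that counts the trainable parameters by hand. It assumes the layer widths used in the model code (784 → 150 → 150 → 150 → 10) and that every `nn.Linear` has a weight matrix plus a bias vector (the PyTorch default). The activations (ReLU, Tanh, ELU) have no parameters, so only the four linear layers count. The helper names `linear_params` and `mlp_params` are hypothetical.

```python
def linear_params(n_in, n_out):
    """Parameters of one fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

def mlp_params(dims):
    """Total parameters of a stack of Linear layers with the given widths."""
    return sum(linear_params(a, b) for a, b in zip(dims, dims[1:]))

# Widths: 28*28 inputs, three hidden layers of 150 units, 10 output classes.
dims = [28 * 28, 150, 150, 150, 10]
print(mlp_params(dims))  # 164560
```

In PyTorch the same number should come out of `sum(p.numel() for p in model.parameters())` once the model is instantiated; most of it (117,750) sits in the first layer.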

    It uses the same MNIST dataset as the previous post on logistic regression.
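The training code relies on `train_loader`, `test_loader` and `num_epochs` from the previous post, which are not shown here. A minimal sketch of how they could be built, assuming the images and labels are already tensors (for example loaded via `torchvision.datasets.MNIST` with `transforms.ToTensor()`). The `make_loaders` helper and its defaults `batch_size=100` and `n_iters=10000` are hypothetical choices, and the random tensors below only stand in for real MNIST data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loaders(x_train, y_train, x_test, y_test, batch_size=100, n_iters=10000):
    """Wrap tensors in DataLoaders and derive num_epochs from a target
    number of iterations (batch_size and n_iters are hypothetical defaults)."""
    train_loader = DataLoader(TensorDataset(x_train, y_train),
                              batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(x_test, y_test),
                             batch_size=batch_size, shuffle=False)
    batches_per_epoch = len(train_loader)  # ceil(len(x_train) / batch_size)
    num_epochs = max(1, n_iters // batches_per_epoch)
    return train_loader, test_loader, num_epochs

# Random stand-in data with MNIST shapes; swap in the real images and labels.
x_train = torch.rand(1000, 1, 28, 28)
y_train = torch.randint(0, 10, (1000,))
x_test = torch.rand(200, 1, 28, 28)
y_test = torch.randint(0, 10, (200,))

train_loader, test_loader, num_epochs = make_loaders(x_train, y_train, x_test, y_test)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape, num_epochs)
# torch.Size([100, 1, 28, 28]) torch.Size([100]) 1000
```

Each batch has shape `(100, 1, 28, 28)`, which is why the training loop flattens it with `images.view(-1, 28*28)` before feeding the first linear layer.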

    Part 1: Building the model

    # Import Libraries
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    
    # Create ANN Model
    class ANNModel(nn.Module):
        
        def __init__(self, input_dim, hidden_dim, output_dim):
            super(ANNModel, self).__init__()
            
            # Linear function 1: 784 --> 150
            self.fc1 = nn.Linear(input_dim, hidden_dim) 
            # Non-linearity 1
            self.relu1 = nn.ReLU()
            
            # Linear function 2: 150 --> 150
            self.fc2 = nn.Linear(hidden_dim, hidden_dim)
            # Non-linearity 2
            self.tanh2 = nn.Tanh()
            
            # Linear function 3: 150 --> 150
            self.fc3 = nn.Linear(hidden_dim, hidden_dim)
            # Non-linearity 3
            self.elu3 = nn.ELU()
            
            # Linear function 4 (readout): 150 --> 10
            self.fc4 = nn.Linear(hidden_dim, output_dim)  
        
        def forward(self, x):
            # Linear function 1
            out = self.fc1(x)
            # Non-linearity 1
            out = self.relu1(out)
            
            # Linear function 2
            out = self.fc2(out)
            # Non-linearity 2
            out = self.tanh2(out)
            
            # Linear function 3
            out = self.fc3(out)
            # Non-linearity 3
            out = self.elu3(out)
            
            # Linear function 4 (readout)
            out = self.fc4(out)
            return out
    
    # instantiate ANN
    input_dim = 28*28
    hidden_dim = 150  # hidden layer width is a hyperparameter to tune; 150 is an arbitrary starting choice
    output_dim = 10
    
    # Create ANN
    model = ANNModel(input_dim, hidden_dim, output_dim)
    
    # Cross Entropy Loss 
    error = nn.CrossEntropyLoss()
    
    # SGD Optimizer
    learning_rate = 0.02
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
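The model's `forward` returns raw scores from `fc4` with no softmax, because `nn.CrossEntropyLoss` applies log-softmax internally and then takes the negative log-likelihood of the true class. A small sketch checking that equivalence on random logits (the batch size, seed and labels are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # raw scores for 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 1])    # true class indices

# Built-in loss: expects raw logits, applies log-softmax internally.
builtin = nn.CrossEntropyLoss()(logits, labels)

# Manual version: log-softmax, pick the true-class entry, negate, average.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(torch.allclose(builtin, manual))  # True
```

Adding an explicit softmax layer at the end of the model would apply softmax twice and distort the loss, so the readout layer is left linear.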
    

     Part 2: Training the model

    # ANN model training
    # num_epochs, train_loader and test_loader are defined as in the previous
    # (logistic regression) post and are not repeated here.
    count = 0
    loss_list = []
    iteration_list = []
    accuracy_list = []
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
    
            # Flatten each 28x28 image into a 784-dim vector
            # (Variable wrappers are unnecessary since PyTorch 0.4)
            train = images.view(-1, 28*28)
            
            # Clear gradients
            optimizer.zero_grad()
            
            # Forward propagation
            outputs = model(train)
            
            # Calculate softmax and cross entropy loss
            loss = error(outputs, labels)
            
            # Calculate gradients
            loss.backward()
            
            # Update parameters
            optimizer.step()
            
            count += 1
            
            if count % 50 == 0:
                # Calculate accuracy on the test set
                correct = 0
                total = 0
                # No gradients are needed for evaluation
                with torch.no_grad():
                    for images, labels in test_loader:
                        test = images.view(-1, 28*28)
                        
                        # Forward propagation
                        outputs = model(test)
                        
                        # Get predictions from the maximum value
                        predicted = torch.max(outputs, 1)[1]
                        
                        # Total number of labels
                        total += len(labels)
    
                        # Total correct predictions
                        correct += (predicted == labels).sum().item()
                
                accuracy = 100 * correct / float(total)
                
                # Store loss and iteration
                loss_list.append(loss.item())
                iteration_list.append(count)
                accuracy_list.append(accuracy)
            if count % 500 == 0:
                # Print loss
                print('Iteration: {}  Loss: {}  Accuracy: {} %'.format(count, loss.item(), accuracy))
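The three calls at the heart of the loop (`zero_grad`, `backward`, `step`) amount to plain gradient descent here, because `torch.optim.SGD` is created without momentum or weight decay. A minimal sketch on a tiny stand-in layer (the `nn.Linear(3, 2)` layer and random data are hypothetical) checking that one `step()` moves every parameter by `-lr * grad`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2)                   # tiny stand-in for the ANN
x = torch.randn(5, 3)
target = torch.randint(0, 2, (5,))
lr = 0.02                                 # same learning rate as the post
optimizer = torch.optim.SGD(layer.parameters(), lr=lr)

# Copy the parameters before the update so we can compare afterwards.
before = [p.detach().clone() for p in layer.parameters()]

optimizer.zero_grad()                     # clear old gradients
loss = nn.CrossEntropyLoss()(layer(x), target)
loss.backward()                           # fill p.grad for every parameter
grads = [p.grad.detach().clone() for p in layer.parameters()]
optimizer.step()                          # plain SGD: p <- p - lr * grad

for p, p0, g in zip(layer.parameters(), before, grads):
    assert torch.allclose(p.detach(), p0 - lr * g)
print("SGD step matches p - lr * grad")
```

Skipping `zero_grad()` would make `backward()` add the new gradients to the old ones, which is why it is called at the start of every iteration.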
    

     Results:

    Iteration: 500  Loss: 0.8311067223548889  Accuracy: 77 %
    Iteration: 1000  Loss: 0.4767582416534424  Accuracy: 87 %
    Iteration: 1500  Loss: 0.21807175874710083  Accuracy: 89 %
    Iteration: 2000  Loss: 0.2915269732475281  Accuracy: 90 %
    Iteration: 2500  Loss: 0.3073478937149048  Accuracy: 91 %
    Iteration: 3000  Loss: 0.12328791618347168  Accuracy: 92 %
    Iteration: 3500  Loss: 0.24098418653011322  Accuracy: 93 %
    Iteration: 4000  Loss: 0.06471655517816544  Accuracy: 93 %
    Iteration: 4500  Loss: 0.3368555009365082  Accuracy: 94 %
    Iteration: 5000  Loss: 0.12026549130678177  Accuracy: 94 %
    Iteration: 5500  Loss: 0.217212975025177  Accuracy: 94 %
    Iteration: 6000  Loss: 0.20914879441261292  Accuracy: 94 %
    Iteration: 6500  Loss: 0.10008767992258072  Accuracy: 95 %
    Iteration: 7000  Loss: 0.13490895926952362  Accuracy: 95 %
    Iteration: 7500  Loss: 0.11741413176059723  Accuracy: 95 %
    Iteration: 8000  Loss: 0.17519493401050568  Accuracy: 95 %
    Iteration: 8500  Loss: 0.06657659262418747  Accuracy: 95 %
    Iteration: 9000  Loss: 0.05512683466076851  Accuracy: 95 %
    Iteration: 9500  Loss: 0.02535334974527359  Accuracy: 96 %
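The accuracy in this log is `100 * correct / total` over the whole test set. A minimal sketch of that evaluation pulled out into a reusable function; `evaluate` and its `input_dim` parameter are hypothetical names, and the toy check uses an `nn.Identity` "model" whose inputs are already one-hot scores so the expected accuracy is known in advance.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, input_dim=28 * 28):
    """Return test accuracy in percent: 100 * correct / total."""
    model.eval()                      # no dropout/batch-norm here, but a good habit
    correct, total = 0, 0
    with torch.no_grad():             # no autograd graph needed for evaluation
        for images, labels in loader:
            outputs = model(images.view(-1, input_dim))
            predicted = outputs.argmax(dim=1)    # index of the largest score
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    model.train()                     # switch back to training mode
    return 100.0 * correct / total

# Toy check: 20 one-hot inputs, the first 5 labels deliberately wrong.
preds = torch.arange(10).repeat(2)    # classes 0..9 twice
inputs = torch.eye(10)[preds]         # one-hot rows, so argmax == preds
labels = preds.clone()
labels[:5] = (labels[:5] + 1) % 10
loader = DataLoader(TensorDataset(inputs, labels), batch_size=8)
print(evaluate(nn.Identity(), loader, input_dim=10))  # 75.0
```

With the real model, calling `evaluate(model, test_loader)` every 50 iterations would replace the inline test loop.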
    

     Part 3: Visualization

    import matplotlib.pyplot as plt
    
    # Visualize loss
    plt.plot(iteration_list, loss_list)
    plt.xlabel("Number of iteration")
    plt.ylabel("Loss")
    plt.title("ANN: Loss vs Number of iteration")
    plt.show()
    
    # Visualize accuracy
    plt.plot(iteration_list, accuracy_list, color="red")
    plt.xlabel("Number of iteration")
    plt.ylabel("Accuracy")
    plt.title("ANN: Accuracy vs Number of iteration")
    plt.show()
    

     Results: [Figures: "ANN: Loss vs Number of iteration" and "ANN: Accuracy vs Number of iteration"]
  • Original article: https://www.cnblogs.com/zhuozige/p/14696695.html