  • 【邱希鹏】nndl-chap5: Digit Recognition (PyTorch)

    1. Import packages

    import os
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import torch.utils.data as Data
    import torchvision
    import torch.nn.functional as F
    import numpy as np
    from torchvision import datasets, transforms
    learning_rate = 1e-4
    keep_prob_rate = 0.7  # used as the dropout probability below
    max_epoch = 3
    BATCH_SIZE = 50
    
    DOWNLOAD_MNIST = False
    if not(os.path.exists('MNIST')) or not os.listdir('MNIST'):
        # download if the MNIST directory is missing or empty
        DOWNLOAD_MNIST = True
    

    2. Load the data

    train_data = torchvision.datasets.MNIST(root='./', train = True, download=DOWNLOAD_MNIST,
                                            transform = torchvision.transforms.Compose([
                                                transforms.ToTensor(),
                                            ]))
    train_loader = Data.DataLoader(dataset = train_data ,batch_size= BATCH_SIZE ,shuffle= True)
    
    test_data = torchvision.datasets.MNIST(root = './', train = False,
                                          transform = transforms.Compose([
                                              transforms.ToTensor(),
                                          ]))
    test_loader = Data.DataLoader(dataset = test_data, batch_size = BATCH_SIZE, shuffle = True)
    
    # first 500 test images as a fixed evaluation set, scaled to [0, 1]
    # (.data / .targets are the raw MNIST image and label tensors)
    test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:500] / 255.
    test_y = test_data.targets[:500].numpy()
    print(test_x.shape)
    print(test_y.shape)
    
    # torch.Size([500, 1, 28, 28])
    # (500,)
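
    As a quick sanity check (a small sketch added here, not part of the original notebook), one batch can be drawn from train_loader to confirm that each batch has shape [BATCH_SIZE, 1, 28, 28] and that ToTensor has already scaled pixel values into [0, 1]:

    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)                 # torch.Size([50, 1, 28, 28]) torch.Size([50])
    print(images.min().item(), images.max().item())   # roughly 0.0 and 1.0 after ToTensor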
    

    3. CNN model

    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            self.conv1 = nn.Sequential(
                nn.Conv2d(
                    # 7*7 kernel, 1 input channel, 32 output channels, stride 1, no padding:
                    # a 28*28 input yields 22*22 feature maps, halved to 11*11 by the pooling layer below
                    in_channels = 1,
                    out_channels = 32,
                    kernel_size = 7,
                    stride = 1,
                    padding = 0,
                ),
                nn.ReLU(),        # activation function
                nn.MaxPool2d(2),  # 2*2 max pooling
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(
                    # 5*5 kernel, 32 input channels, 64 output channels, stride 1, no padding:
                    # an 11*11 input yields 7*7 feature maps; MaxPool2d(1) below leaves the size unchanged
                    in_channels = 32,
                    out_channels = 64,
                    kernel_size = 5,
                    stride = 1,
                    padding = 0,
                ),
                nn.ReLU(),
                nn.MaxPool2d(1),
            )
            self.out1 = nn.Linear(7*7*64, 1024, bias=True)   # first fully connected layer

            self.dropout = nn.Dropout(keep_prob_rate)        # nn.Dropout takes the drop probability, so 70% of units are zeroed here
            self.out2 = nn.Linear(1024, 10, bias=True)       # second fully connected layer: 10 class scores
    
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = x.view(-1, 64*7*7)  # flatten the output of conv2 to (batch_size, 64*7*7)
            out1 = self.out1(x)
            out1 = F.relu(out1)
            out1 = self.dropout(out1)
            out2 = self.out2(out1)
            output = F.softmax(out2, dim=1)  # note: nn.CrossEntropyLoss already applies log-softmax, so returning the raw logits out2 would be the more standard choice
            return output
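
    Because both convolutions use padding = 0, the 7*7*64 flatten size is easy to miss. The following sketch (not in the original post; _probe and _dummy are just illustrative names) traces the feature-map sizes with a dummy batch: 28x28 shrinks to 22x22 after the 7x7 convolution, to 11x11 after 2x2 pooling, and to 7x7 after the 5x5 convolution, while MaxPool2d(1) leaves it unchanged.

    # Shape check (sketch): run a dummy batch through each stage and inspect the sizes.
    _probe = CNN()
    _dummy = torch.zeros(1, 1, 28, 28)   # one fake 28x28 grayscale image
    _c1 = _probe.conv1(_dummy)
    _c2 = _probe.conv2(_c1)
    print(_c1.shape)   # torch.Size([1, 32, 11, 11])
    print(_c2.shape)   # torch.Size([1, 64, 7, 7])  -> 64*7*7 = 3136 flattened features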
    

    4. Training and testing

    def test(cnn):
        # accuracy on the fixed 500-image evaluation set
        global prediction
        y_pre = cnn(test_x)
        _, pre_index = torch.max(y_pre, 1)
        pre_index = pre_index.view(-1)
        prediction = pre_index.data.numpy()
        correct = np.sum(prediction == test_y)
        return correct / 500.0
    
    
    def train(cnn):
        optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
        loss_func = nn.CrossEntropyLoss()
        for epoch in range(max_epoch):
            for step, (x_, y_) in enumerate(train_loader):
                x, y = Variable(x_), Variable(y_)   # Variable is a no-op wrapper in current PyTorch; the tensors could be used directly
                output = cnn(x)
                loss = loss_func(output, y)   # scalar loss, averaged over the batch

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if step != 0 and step % 20 == 0:
                    print("=" * 10, step, "=" * 5, "=" * 5, "test accuracy is ", test(cnn), "=" * 10)
    

    4.1 Training

    cnn = CNN()
    train(cnn)
    
    ========== 20 ===== ===== test accuracy is  0.25 ==========
    ========== 40 ===== ===== test accuracy is  0.458 ==========
    ========== 60 ===== ===== test accuracy is  0.572 ==========
    ========== 80 ===== ===== test accuracy is  0.624 ==========
    ========== 100 ===== ===== test accuracy is  0.638 ==========
    ========== 120 ===== ===== test accuracy is  0.718 ==========
    ========== 140 ===== ===== test accuracy is  0.744 ==========
    ========== 160 ===== ===== test accuracy is  0.796 ==========
    ========== 180 ===== ===== test accuracy is  0.806 ==========
    ========== 200 ===== ===== test accuracy is  0.808 ==========
    ========== 220 ===== ===== test accuracy is  0.844 ==========
    ========== 240 ===== ===== test accuracy is  0.84 ==========
    ========== 260 ===== ===== test accuracy is  0.864 ==========
    ========== 280 ===== ===== test accuracy is  0.86 ==========
    ========== 300 ===== ===== test accuracy is  0.878 ==========
    ========== 320 ===== ===== test accuracy is  0.868 ==========
    ========== 340 ===== ===== test accuracy is  0.876 ==========
    ========== 360 ===== ===== test accuracy is  0.862 ==========
    ========== 380 ===== ===== test accuracy is  0.86 ==========
    ========== 400 ===== ===== test accuracy is  0.892 ==========
    ========== 420 ===== ===== test accuracy is  0.87 ==========
    ========== 440 ===== ===== test accuracy is  0.882 ==========
    ========== 460 ===== ===== test accuracy is  0.898 ==========
    ========== 480 ===== ===== test accuracy is  0.892 ==========
    ========== 500 ===== ===== test accuracy is  0.884 ==========
    ========== 520 ===== ===== test accuracy is  0.892 ==========
    ========== 540 ===== ===== test accuracy is  0.892 ==========
    ========== 560 ===== ===== test accuracy is  0.902 ==========
    ========== 580 ===== ===== test accuracy is  0.902 ==========
    ========== 600 ===== ===== test accuracy is  0.904 ==========
    ========== 620 ===== ===== test accuracy is  0.902 ==========
    ========== 640 ===== ===== test accuracy is  0.904 ==========
    ========== 660 ===== ===== test accuracy is  0.906 ==========
    ========== 680 ===== ===== test accuracy is  0.908 ==========
    ========== 700 ===== ===== test accuracy is  0.922 ==========
    ========== 720 ===== ===== test accuracy is  0.916 ==========
    ========== 740 ===== ===== test accuracy is  0.918 ==========
    ========== 760 ===== ===== test accuracy is  0.906 ==========
    ========== 780 ===== ===== test accuracy is  0.924 ==========
    ========== 800 ===== ===== test accuracy is  0.928 ==========
    ========== 820 ===== ===== test accuracy is  0.918 ==========
    ========== 840 ===== ===== test accuracy is  0.922 ==========
    ========== 860 ===== ===== test accuracy is  0.918 ==========
    ========== 880 ===== ===== test accuracy is  0.93 ==========
    ========== 900 ===== ===== test accuracy is  0.924 ==========
    ========== 920 ===== ===== test accuracy is  0.932 ==========
    ========== 940 ===== ===== test accuracy is  0.934 ==========
    ========== 960 ===== ===== test accuracy is  0.926 ==========
    ========== 980 ===== ===== test accuracy is  0.932 ==========
    ========== 1000 ===== ===== test accuracy is  0.934 ==========
    ========== 1020 ===== ===== test accuracy is  0.926 ==========
    ========== 1040 ===== ===== test accuracy is  0.924 ==========
    ========== 1060 ===== ===== test accuracy is  0.934 ==========
    ========== 1080 ===== ===== test accuracy is  0.932 ==========
    ========== 1100 ===== ===== test accuracy is  0.934 ==========
    ========== 1120 ===== ===== test accuracy is  0.936 ==========
    ========== 1140 ===== ===== test accuracy is  0.936 ==========
    ========== 1160 ===== ===== test accuracy is  0.93 ==========
    ========== 1180 ===== ===== test accuracy is  0.932 ==========
    ========== 20 ===== ===== test accuracy is  0.934 ==========
    ========== 40 ===== ===== test accuracy is  0.946 ==========
    ========== 60 ===== ===== test accuracy is  0.94 ==========
    ========== 80 ===== ===== test accuracy is  0.946 ==========
    ========== 100 ===== ===== test accuracy is  0.946 ==========
    ========== 120 ===== ===== test accuracy is  0.944 ==========
    ========== 140 ===== ===== test accuracy is  0.946 ==========
    ========== 160 ===== ===== test accuracy is  0.956 ==========
    ========== 180 ===== ===== test accuracy is  0.936 ==========
    ========== 200 ===== ===== test accuracy is  0.95 ==========
    ========== 220 ===== ===== test accuracy is  0.956 ==========
    ========== 240 ===== ===== test accuracy is  0.946 ==========
    ========== 260 ===== ===== test accuracy is  0.944 ==========
    ========== 280 ===== ===== test accuracy is  0.944 ==========
    ========== 300 ===== ===== test accuracy is  0.954 ==========
    ========== 320 ===== ===== test accuracy is  0.964 ==========
    ========== 340 ===== ===== test accuracy is  0.95 ==========
    ========== 360 ===== ===== test accuracy is  0.962 ==========
    ========== 380 ===== ===== test accuracy is  0.948 ==========
    ========== 400 ===== ===== test accuracy is  0.96 ==========
    ========== 420 ===== ===== test accuracy is  0.946 ==========
    ========== 440 ===== ===== test accuracy is  0.96 ==========
    ========== 460 ===== ===== test accuracy is  0.948 ==========
    ========== 480 ===== ===== test accuracy is  0.95 ==========
    ========== 500 ===== ===== test accuracy is  0.958 ==========
    ========== 520 ===== ===== test accuracy is  0.954 ==========
    ========== 540 ===== ===== test accuracy is  0.948 ==========
    ========== 560 ===== ===== test accuracy is  0.958 ==========
    ========== 580 ===== ===== test accuracy is  0.948 ==========
    ========== 600 ===== ===== test accuracy is  0.96 ==========
    ========== 620 ===== ===== test accuracy is  0.96 ==========
    ========== 640 ===== ===== test accuracy is  0.96 ==========
    ========== 660 ===== ===== test accuracy is  0.95 ==========
    ========== 680 ===== ===== test accuracy is  0.962 ==========
    ========== 700 ===== ===== test accuracy is  0.964 ==========
    ========== 720 ===== ===== test accuracy is  0.962 ==========
    ========== 740 ===== ===== test accuracy is  0.96 ==========
    ========== 760 ===== ===== test accuracy is  0.954 ==========
    ========== 780 ===== ===== test accuracy is  0.956 ==========
    ========== 800 ===== ===== test accuracy is  0.962 ==========
    ========== 820 ===== ===== test accuracy is  0.962 ==========
    ========== 840 ===== ===== test accuracy is  0.968 ==========
    ========== 860 ===== ===== test accuracy is  0.962 ==========
    ========== 880 ===== ===== test accuracy is  0.972 ==========
    ========== 900 ===== ===== test accuracy is  0.96 ==========
    ========== 920 ===== ===== test accuracy is  0.958 ==========
    ========== 940 ===== ===== test accuracy is  0.966 ==========
    ========== 960 ===== ===== test accuracy is  0.972 ==========
    ========== 980 ===== ===== test accuracy is  0.964 ==========
    ========== 1000 ===== ===== test accuracy is  0.968 ==========
    ========== 1020 ===== ===== test accuracy is  0.968 ==========
    ========== 1040 ===== ===== test accuracy is  0.956 ==========
    ========== 1060 ===== ===== test accuracy is  0.96 ==========
    ========== 1080 ===== ===== test accuracy is  0.97 ==========
    ========== 1100 ===== ===== test accuracy is  0.968 ==========
    ========== 1120 ===== ===== test accuracy is  0.964 ==========
    ========== 1140 ===== ===== test accuracy is  0.97 ==========
    ========== 1160 ===== ===== test accuracy is  0.97 ==========
    ========== 1180 ===== ===== test accuracy is  0.96 ==========
    ========== 20 ===== ===== test accuracy is  0.96 ==========
    ========== 40 ===== ===== test accuracy is  0.962 ==========
    ========== 60 ===== ===== test accuracy is  0.97 ==========
    ========== 80 ===== ===== test accuracy is  0.958 ==========
    ========== 100 ===== ===== test accuracy is  0.966 ==========
    ========== 120 ===== ===== test accuracy is  0.962 ==========
    ========== 140 ===== ===== test accuracy is  0.968 ==========
    ========== 160 ===== ===== test accuracy is  0.972 ==========
    ========== 180 ===== ===== test accuracy is  0.972 ==========
    ========== 200 ===== ===== test accuracy is  0.978 ==========
    ========== 220 ===== ===== test accuracy is  0.968 ==========
    ========== 240 ===== ===== test accuracy is  0.956 ==========
    ========== 260 ===== ===== test accuracy is  0.97 ==========
    ========== 280 ===== ===== test accuracy is  0.964 ==========
    ========== 300 ===== ===== test accuracy is  0.97 ==========
    ========== 320 ===== ===== test accuracy is  0.972 ==========
    ========== 340 ===== ===== test accuracy is  0.976 ==========
    ========== 360 ===== ===== test accuracy is  0.968 ==========
    ========== 380 ===== ===== test accuracy is  0.97 ==========
    ========== 400 ===== ===== test accuracy is  0.974 ==========
    ========== 420 ===== ===== test accuracy is  0.974 ==========
    ========== 440 ===== ===== test accuracy is  0.968 ==========
    ========== 460 ===== ===== test accuracy is  0.976 ==========
    ========== 480 ===== ===== test accuracy is  0.97 ==========
    ========== 500 ===== ===== test accuracy is  0.96 ==========
    ========== 520 ===== ===== test accuracy is  0.966 ==========
    ========== 540 ===== ===== test accuracy is  0.974 ==========
    ========== 560 ===== ===== test accuracy is  0.974 ==========
    ========== 580 ===== ===== test accuracy is  0.972 ==========
    ========== 600 ===== ===== test accuracy is  0.974 ==========
    ========== 620 ===== ===== test accuracy is  0.97 ==========
    ========== 640 ===== ===== test accuracy is  0.974 ==========
    ========== 660 ===== ===== test accuracy is  0.976 ==========
    ========== 680 ===== ===== test accuracy is  0.97 ==========
    ========== 700 ===== ===== test accuracy is  0.974 ==========
    ========== 720 ===== ===== test accuracy is  0.962 ==========
    ========== 740 ===== ===== test accuracy is  0.98 ==========
    ========== 760 ===== ===== test accuracy is  0.976 ==========
    ========== 780 ===== ===== test accuracy is  0.966 ==========
    ========== 800 ===== ===== test accuracy is  0.968 ==========
    ========== 820 ===== ===== test accuracy is  0.974 ==========
    ========== 840 ===== ===== test accuracy is  0.964 ==========
    ========== 860 ===== ===== test accuracy is  0.974 ==========
    ========== 880 ===== ===== test accuracy is  0.974 ==========
    ========== 900 ===== ===== test accuracy is  0.982 ==========
    ========== 920 ===== ===== test accuracy is  0.972 ==========
    ========== 940 ===== ===== test accuracy is  0.974 ==========
    ========== 960 ===== ===== test accuracy is  0.976 ==========
    ========== 980 ===== ===== test accuracy is  0.976 ==========
    ========== 1000 ===== ===== test accuracy is  0.984 ==========
    ========== 1020 ===== ===== test accuracy is  0.976 ==========
    ========== 1040 ===== ===== test accuracy is  0.976 ==========
    ========== 1060 ===== ===== test accuracy is  0.982 ==========
    ========== 1080 ===== ===== test accuracy is  0.974 ==========
    ========== 1100 ===== ===== test accuracy is  0.976 ==========
    ========== 1120 ===== ===== test accuracy is  0.974 ==========
    ========== 1140 ===== ===== test accuracy is  0.98 ==========
    ========== 1160 ===== ===== test accuracy is  0.98 ==========
    ========== 1180 ===== ===== test accuracy is  0.978 ==========
    

    4.2 Testing

    def predict(test_x, idx):
        y_pre = cnn(test_x)
        print(y_pre.shape)
        _, pre_index = torch.max(y_pre, 1)
        prediction = pre_index.data.numpy()
        print(prediction)
        print("img: ", test_y[idx])   # test_y is already a NumPy array, so index it directly
    
    import matplotlib.pyplot as plt
    
    def showTorchImage(image):
        mode = transforms.ToPILImage()(image)
        plt.imshow(mode)
        plt.show()
        
    idx = 18
    showTorchImage(test_x[idx, :, :, :])
    predict(test_x, idx)
    

    torch.Size([50, 10])
    [7 8 8 3 1 5 1 6 9 4 3 5 8 1 7 1 6 9 2 7 6 2 3 9 5 1 4 7 0 0 5 0 9 2 9 2 6
    3 6 1 9 2 5 7 2 0 5 6 2 6]
    img: 2
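
    As a final check (again a small sketch beyond the original notebook), the single displayed image can be classified on its own and compared against its label:

    cnn.eval()   # disable dropout for this single prediction
    with torch.no_grad():
        single_pred = cnn(test_x[idx:idx + 1]).argmax(dim=1).item()
    print("predicted:", single_pred, "  label:", test_y[idx])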

  • Original post: https://www.cnblogs.com/douzujun/p/13454960.html