zoukankan      html  css  js  c++  java
  • 卷积神经网络

    """
    手写字体的训练
    """
    import os
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import torch.utils.data as Data
    import torchvision
    import matplotlib.pyplot as plt
    
    # 超参数
    EPOCH = 1
    BATCH_SIZE = 50
    LR = 0.001
    DOWNLOAD_MNIST = False
    
    # 确认有没有mnist数据集
    if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):
        DOWNLOAD_MNIST = True
    
    # 在mnist官网下载训练数据集
    train_data = torchvision.datasets.MNIST(
        root='./mnist/', # 文件下载后的保存路径
        train=True, # True代表训练数据集,False代表测试数据集
        transform=torchvision.transforms.ToTensor(), # 将下载后的数据转换为tensor格式,并将数据归一化到0到1之间                                                    
        download=DOWNLOAD_MNIST,# 是否下载数据集
    )
    
    print(train_data.train_data.size())                 # (60000, 28, 28)
    print(train_data.train_labels.size())               # (60000)
    
    # 显示出第一张图片
    plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
    plt.title('%i' % train_data.train_labels[0])
    plt.show()
    
    # 将以上数据集分为多批
    train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    
    # 测试数据集(还没有经过transform操作,其数据范围是0到255之间)
    test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
    
    # 将以下数据归一化到0到1之间,所以要除以255
    test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.   # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
    test_y = test_data.test_labels[:2000]
    
    
    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            self.conv1 = nn.Sequential(         # 图片的形状为(1, 28, 28)
                nn.Conv2d(
                    in_channels=1,              # 被卷积的通道数
                    out_channels=16,            # 输出的通道数
                    kernel_size=5,              # 卷积核的大小为(5x5)
                    stride=1,                   # 卷积核的移动步数为1
                    padding=2,                  # 图片的拓展圈数
                ),                              # 输出形状为(16, 28, 28)
                nn.ReLU(),                      # 激活函数
                nn.MaxPool2d(kernel_size=2),    # 最大池化后输出形状为(16, 14, 14)
            )
            self.conv2 = nn.Sequential(         # 输入形状为(16, 14, 14)
                nn.Conv2d(16, 32, 5, 1, 2),     # 输出形状为(32, 14, 14)
                nn.ReLU(),                      # 激活函数
                nn.MaxPool2d(2),                # 最大池化后输出形状为(32, 7, 7)
            )
            self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)          # x的形状为(batch_size, 32, 7, 7)
            x = x.view(x.size(0), -1)  # 执行后x的形状为(batch_size, 32 * 7 * 7)
            output = self.out(x)
            return output
    
    
    cnn = CNN()
    print(cnn)  # 打印出网络结构
    
    # 优化所有的网络参数
    optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
    
    #计算损失值
    loss_func = nn.CrossEntropyLoss()
    
    # 训练及测试
    for epoch in range(EPOCH):
        for step, (x, y) in enumerate(train_loader):
            b_x = Variable(x)
            b_y = Variable(y)
    
            output = cnn(b_x)
            loss = loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            if step % 50 == 0:
                test_output = cnn(test_x)
                pred_y = torch.max(test_output, 1)[1].data.squeeze()
                accuracy = sum(pred_y == test_y) / float(test_y.size(0))
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)
    
    # 打印出前十个图片的预测效果
    test_output, _ = cnn(test_x[:10])
    pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
    print(pred_y, 'prediction number')
    print(test_y[:10].numpy(), 'real number')
  • 相关阅读:
    bzoj2733 永无乡 平衡树按秩合并
    bzoj2752 高速公路 线段树
    bzoj1052 覆盖问题 二分答案 dfs
    bzoj1584 打扫卫生 dp
    bzoj1854 游戏 二分图
    bzoj3316 JC loves Mkk 二分答案 单调队列
    bzoj3643 Phi的反函数 数学 搜索
    有一种恐怖,叫大爆搜
    BZOJ3566 概率充电器 概率dp
    一些奇奇怪怪的过题思路
  • 原文地址:https://www.cnblogs.com/czz0508/p/10336794.html
Copyright © 2011-2022 走看看