""" 手写字体的训练 """ import os import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt # 超参数 EPOCH = 1 BATCH_SIZE = 50 LR = 0.001 DOWNLOAD_MNIST = False # 确认有没有mnist数据集 if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'): DOWNLOAD_MNIST = True # 在mnist官网下载训练数据集 train_data = torchvision.datasets.MNIST( root='./mnist/', # 文件下载后的保存路径 train=True, # True代表训练数据集,False代表测试数据集 transform=torchvision.transforms.ToTensor(), # 将下载后的数据转换为tensor格式,并将数据归一化到0到1之间 download=DOWNLOAD_MNIST,# 是否下载数据集 ) print(train_data.train_data.size()) # (60000, 28, 28) print(train_data.train_labels.size()) # (60000) # 显示出第一张图片 plt.imshow(train_data.train_data[0].numpy(), cmap='gray') plt.title('%i' % train_data.train_labels[0]) plt.show() # 将以上数据集分为多批 train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # 测试数据集(还没有经过transform操作,其数据范围是0到255之间) test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) # 将以下数据归一化到0到1之间,所以要除以255 test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255. 
# test_x (built above) has shape (2000, 1, 28, 28) with values in [0, 1];
# take the matching first 2000 labels.
test_y = test_data.targets[:2000]


class CNN(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by a linear classifier.

    Expects input of shape (batch_size, 1, 28, 28) and returns the raw class
    scores (logits) of shape (batch_size, 10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(  # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,  # channels of the incoming image
                out_channels=16,  # number of feature maps produced
                kernel_size=5,  # 5x5 kernel
                stride=1,  # kernel moves one pixel at a time
                padding=2,  # zero padding that keeps the 28x28 size
            ),  # output shape (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(  # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),  # output shape (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),  # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer

    def forward(self, x):
        """Return the class scores for a batch of images."""
        x = self.conv1(x)
        x = self.conv2(x)  # (batch_size, 32, 7, 7)
        x = x.view(x.size(0), -1)  # flatten to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output


cnn = CNN()
print(cnn)  # print the network structure

# Optimise every parameter of the network.
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
# Loss function for the classification task.
loss_func = nn.CrossEntropyLoss()

# Training, with a periodic evaluation on the held-out test subset.
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            # No gradients are needed for evaluation; skipping them avoids
            # building an autograd graph over all 2000 test images.
            with torch.no_grad():
                test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1]
            # Count matches with a tensor reduction and convert to a Python
            # number, instead of summing tensor elements one by one.
            correct = (pred_y == test_y).sum().item()
            accuracy = correct / float(test_y.size(0))
            # loss.item() extracts the scalar; indexing a 0-dim tensor with
            # [0] fails on current PyTorch.
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)

# Print the predictions for the first ten test images.
# forward() returns a single tensor, so it must not be unpacked into two names.
with torch.no_grad():
    test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')