zoukankan      html  css  js  c++  java
  • 用pytorch进行CIFAR-10数据集分类

    CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky、Vinod Nair 与 Geoffrey Hinton 收集的一个用于图像识别的数据集,60000个32*32的彩色图像,50000个training data,10000个 test data 有10类,飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车,每类6000张图。与MNIST相比,色彩、颜色噪点较多,同一类物体大小不一、角度不同、颜色不同。

     先要对该数据集进行分类

    步骤如下
    1.使用torchvision加载并预处理CIFAR-10数据集、
    2.定义网络
    3.定义损失函数和优化器
    4.训练网络并更新网络参数
    5.测试网络

     1 import torchvision as tv            #里面含有许多数据集
     2 import torch
     3 import torchvision.transforms as transforms    #实现图片变换处理的包
     4 from torchvision.transforms import ToPILImage
     5 
     6 #使用torchvision加载并预处理CIFAR10数据集
     7 show = ToPILImage()         #可以把Tensor转成Image,方便进行可视化
     8 transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean = (0.5,0.5,0.5),std = (0.5,0.5,0.5))])#把数据变为tensor并且归一化range [0, 255] -> [0.0,1.0]
     9 trainset = tv.datasets.CIFAR10(root='data1/',train = True,download=True,transform=transform)
    10 trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=0)
    11 testset = tv.datasets.CIFAR10('data1/',train=False,download=True,transform=transform)
    12 testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=True,num_workers=0)
    13 classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
    14 (data,label) = trainset[100]
    15 print(classes[label])#输出ship
    16 show((data+1)/2).resize((100,100))
    17 dataiter = iter(trainloader)
    18 images, labels = dataiter.next()
    19 print(' '.join('%11s'%classes[labels[j]] for j in range(4)))
    20 show(tv.utils.make_grid((images+1)/2)).resize((400,100))#make_grid的作用是将若干幅图像拼成一幅图像
    21 
    22 #定义网络
    23 import torch.nn as nn
    24 import torch.nn.functional as F
    25 class Net(nn.Module):
    26     def __init__(self):
    27         super(Net,self).__init__()
    28         self.conv1 = nn.Conv2d(3,6,5)
    29         self.conv2 = nn.Conv2d(6,16,5)
    30         self.fc1 = nn.Linear(16*5*5,120)
    31         self.fc2 = nn.Linear(120,84)
    32         self.fc3 = nn.Linear(84,10)
    33     def forward(self,x):
    34         x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
    35         x = F.max_pool2d(F.relu(self.conv2(x)),2)
    36         x = x.view(x.size()[0],-1)
    37         x = F.relu(self.fc1(x))
    38         x = F.relu(self.fc2(x))
    39         x = self.fc3(x)
    40         return  x
    41 
    42 net = Net()
    43 print(net)
    44 
    45 #定义损失函数和优化器
    46 from torch import optim
    47 criterion  = nn.CrossEntropyLoss()#定义交叉熵损失函数
    48 optimizer = optim.SGD(net.parameters(),lr = 0.001,momentum=0.9)
    49 
    50 #训练网络
    51 from torch.autograd  import Variable
    52 for epoch in range(2):
    53     running_loss = 0.0
    54     for i, data in enumerate(trainloader, 0):#enumerate将其组成一个索引序列,利用它可以同时获得索引和值,enumerate还可以接收第二个参数,用于指定索引起始值
    55         inputs, labels = data
    56         inputs, labels = Variable(inputs), Variable(labels)
    57         optimizer.zero_grad()
    58         outputs = net(inputs)
    59         loss  = criterion(outputs, labels)
    60         loss.backward()
    61         optimizer.step()
    62         running_loss += loss.item()
    63         if i % 2000 ==1999:
    64             print('[%d, %5d] loss: %.3f'%(epoch+1,i+1,running_loss/2000))
    65             running_loss = 0.0
    66 print("----------finished training---------")
    67 dataiter = iter(testloader)
    68 images, labels = dataiter.next()
    69 print('实际的label: ',' '.join('%08s'%classes[labels[j]] for j in range(4)))
    70 show(tv.utils.make_grid(images/2 - 0.5)).resize((400,100))#?????
    71 outputs = net(Variable(images))
    72 _, predicted = torch.max(outputs.data,1)#返回最大值和其索引
    73 print('预测结果:',' '.join('%5s'%classes[predicted[j]] for j in range(4)))
    74 correct = 0
    75 total = 0
    76 for data in testloader:
    77     images, labels = data
    78     outputs = net(Variable(images))
    79     _, predicted = torch.max(outputs.data, 1)
    80     total +=labels.size(0)
    81     correct +=(predicted == labels).sum()
    82 print('10000张测试集中的准确率为: %d %%'%(100*correct/total))
    83 if torch.cuda.is_available():
    84     net.cuda()
    85     images = images.cuda()
    86     labels = labels.cuda()
    87     output = net(Variable(images))
    88     loss = criterion(output, Variable(labels))

    学习率太大会很难逼近最优值,所以要注意在数据集小的情况下学习率尽量小一些,epoch尽量大一些。

    这个例子是陈云的深度学习pytorch框架书上的一个demo,运行该代码需要注意的是数据集的下载问题,因为运行程序很可能数据集下载很慢或者直接下载失败,因此推荐使用迅雷根据指定网址直接下载,半分钟就可以下载好。

  • 相关阅读:
    自考新教材-p173_3(1)
    自考新教材-p148_5(2)
    自考新教材-p148_5(1)
    自考新教材-p148_4
    自考新教材-p147_3
    自考新教材-p146_4(2)
    python 模块 chardet报错解决方法:下载及介绍
    第 52 讲:论一只爬虫的自我修养
    第 51 讲: _name_属性
    Python 培训第一讲
  • 原文地址:https://www.cnblogs.com/henuliulei/p/11981109.html
Copyright © 2011-2022 走看看