zoukankan      html  css  js  c++  java
  • Pytorch学习:CIFAR-10分类

    最近在学习Pytorch,先照着别人的代码过一遍,加油!!!

    加载数据集

    # 加载数据集及预处理
    import torchvision as tv
    import torchvision.transforms as transforms
    from torchvision.transforms import ToPILImage
    import torch as t
    show=ToPILImage() #可以将Tensor转成Image,方便可视化

    划分数据集为训练集和测试集

    #定义对数据的预处理
    transform=transforms.Compose([
        transforms.ToTensor(),  #转为Tensor
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)), #归一化
    ])
    
    #训练集
    trainset=tv.datasets.CIFAR10(
        root='/home/cy/data',
        train=True,
        download=True,
        transform=transform
    )
    
    trainloader=t.utils.data.DataLoader(
        trainset,
        batch_size=4,
        shuffle=True,
        num_workers=2
    )
    
    testset=tv.datasets.CIFAR10(
        '/home/cy/data/',
        train=False,
        download=True,
        transform=transform
    )
    
    testloader=t.utils.data.DataLoader(
        testset,
        batch_size=4,
        shuffle=False,
        num_workers=2
    )
    
    classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
    Files already downloaded and verified
    Files already downloaded and verified

    可视化看下图片效果
    (data, label)=trainset[100]
    print(classes[label])
    
    #(data+1)是为了还原被归一化的数据
    show((data+1)/2).resize((100,100))

    展示一个mini-batch中的图片

    dataiter=iter(trainloader)
    images,labels=dataiter.next() #返回4张图片及标签
    print(' '.join('%11s'%classes[labels[j]] for j in range(4)))
    show(tv.utils.make_grid((images+1)/2)).resize((400,100))

    定义网络结构,挺方便的

    ## 定义网络
    import torch.nn as nn
    import torch.nn.functional as F
    
    class Net(nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            self.conv1=nn.Conv2d(3,6,5)
            self.conv2=nn.Conv2d(6,16,5)
            self.fc1=nn.Linear(16*5*5,120)
            self.fc2=nn.Linear(120,84)
            self.fc3=nn.Linear(84,10)
            
            
        def forward(self,x):
            x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))
            x=F.max_pool2d(F.relu(self.conv2(x)),2)
            x=x.view(x.size()[0],-1)
            x=F.relu(self.fc1(x))
            x=F.relu(self.fc2(x))
            x=self.fc3(x)
            return x
    
    net=Net()
    print(net)
    Net(
      (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
      (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
      (fc1): Linear(in_features=400, out_features=120, bias=True)
      (fc2): Linear(in_features=120, out_features=84, bias=True)
      (fc3): Linear(in_features=84, out_features=10, bias=True)
    )

    定义损失函数和优化器
    ## 定义损失函数和优化器
    from torch import optim
    criterion=nn.CrossEntropyLoss()  # 交叉熵损失函数
    optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9) #随机梯度下降,stochastic gradient descent

    开始训练网络

    一共有三个步骤。输入数据,前向传播+反向传播,更新参数

    from torch.autograd import Variable
    
    for epoch in range(2):
        running_loss=0.0
        for i,data in enumerate(trainloader,0):
            #输入数据
            inputs,labels=data
            inputs,labels=Variable(inputs),Variable(labels)
            
            #梯度清零
            optimizer.zero_grad()
            
            #forward+backward
            outputs=net(inputs)
            loss=criterion(outputs,labels)
            loss.backward()
            
            #更新参数
            optimizer.step()
            
            #打印log信息
            #running_loss +=loss.data[0]
            running_loss +=loss.item()
            if i%2000 ==1999:   #每2000个batch打印一次训练状态
                print('[%d, %5d] loss: %.3f' 
                     %(epoch+1,i+1,running_loss / 2000))
                running_loss=0.0
    print('Finished Training')

    检查一下网络在一个batch内的效果如何

    ## 检验网络效果
    dataiter=iter(testloader)
    images,labels=dataiter.next() #一个batch返回4张图片
    print('实际的label: ',' '.join(
                '%08s'%classes[labels[j]] for j in range(4)))
    show(tv.utils.make_grid(images/2 -0.5)).resize((400,100))
    
    # 计算网络预测的label
    outputs=net(Variable(images))
    _,predicted=t.max(outputs.data,1)
    print('预测结果: ',' '.join('%5s'
            % classes[predicted[j]] for j in range(4)))

    测试集上计算正确率

    correct=0
    total=0
    for data in testloader:
        images,labels=data
        outputs=net(Variable(images))
        _,predicted=t.max(outputs.data,1)
        total +=labels.size(0)
        correct +=(predicted==labels).sum()
        
    print('1000张测试集中的准确率为: %d  %%' %(100* correct/total))
    1000张测试集中的准确率为: 52  %

    可以看到,在CIFAR-10上的正确率为52%,网络训练还是有些效果的。

  • 相关阅读:
    波段是金牢记六大诀窍
    zk kafka mariadb scala flink integration
    Oracle 体系结构详解
    图解 Database Buffer Cache 内部原理(二)
    SQL Server 字符集介绍及修改方法演示
    SQL Server 2012 备份与还原详解
    SQL Server 2012 查询数据库中所有表的名称和行数
    SQL Server 2012 查询数据库中表格主键信息
    SQL Server 2012 查询数据库中所有表的索引信息
    图解 Database Buffer Cache 内部原理(一)
  • 原文地址:https://www.cnblogs.com/keeptry/p/13943820.html
Copyright © 2011-2022 走看看