zoukankan      html  css  js  c++  java
  • PyTorch搭建AlexNet实现CIFAR10数据集分类(学习记录)

    1.加载数据集:torchvision.datasets.CIFAR10

      

    • transform设置:transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])先将图像转换为tensor(像素值缩放到[0,1]),再对R、G、B三个通道分别做标准化:均值为0.5,标准差为0.5,即把每个像素值映射到[-1,1]区间。
    • DataLoader是PyTorch常用的数据加载类,按批次(batch)提供数据;batch_size指定每个批次的样本数(此处为4)。后面按类别统计预测结果时,需要遍历每个批次中的全部样本,因此循环次数应与批次大小对应。
    # ToTensor scales pixel values to [0, 1]; Normalize then subtracts a per-channel
    # mean of 0.5 and divides by a per-channel std of 0.5, mapping each RGB channel
    # to [-1, 1].  Normalize's signature is (mean, std, inplace), so mean and std are
    # passed as explicit 3-tuples rather than as three positional scalars (a third
    # positional value would be taken as the `inplace` flag).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Training data: downloaded to ./data if missing, shuffled each epoch, 4 images per batch.
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)

    # Test data: same preprocessing.  Order does not affect accuracy, so it is not
    # shuffled, which keeps evaluation runs reproducible.
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)

    2.定义网络结构(注:下面实现的实际上是LeNet-5风格的小型卷积网络——两层5x5卷积加三层全连接,并非原始的AlexNet)

    class Net(nn.Module):
        """Small LeNet-5-style CNN for 3-channel 32x32 images (e.g. CIFAR-10).

        Despite the article title this is not AlexNet.  The layers are:
        - two 5x5 convolutions (3 -> 6 -> 16 channels), each followed by ReLU and
          2x2 max-pooling;
        - three fully connected layers (400 -> 120 -> 84 -> 10).
        """

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            # One pooling module is reused after both convolutions; it has no parameters.
            self.pool1 = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            # 32x32 input -> conv 28x28 -> pool 14x14 -> conv 10x10 -> pool 5x5,
            # so the flattened feature vector has 16 * 5 * 5 = 400 entries.
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            """Return unnormalized class scores of shape (batch, 10).

            x is assumed to be (batch, 3, 32, 32); other spatial sizes will not
            match fc1's 400 input features.
            """
            out = self.pool1(F.relu(self.conv1(x)))
            out = self.pool1(F.relu(self.conv2(out)))
            # flatten(1) keeps the batch dimension intact.  view(-1, 400) would instead
            # silently re-split the batch if the feature map were not 16x5x5.
            out = out.flatten(1)
            out = F.relu(self.fc1(out))
            out = F.relu(self.fc2(out))
            # No activation on the last layer: raw scores are returned.
            out = self.fc3(out)
            return out


    net = Net()
    print(net)

    3、定义损失函数和优化器

    # Loss: nn.CrossEntropyLoss takes unnormalized class scores plus integer labels.
    cost = nn.CrossEntropyLoss()

    # Optimizer: SGD with momentum over all parameters of `net`.
    optimizer = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)

    4.训练网络

    # Train for 2 epochs, printing the mean loss of every 2000 mini-batches.
    for epoch in range(2):
        running_loss = 0.0
        # Tensors take part in autograd directly; the deprecated Variable wrapper is unnecessary.
        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = cost(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                # Reset to exactly zero so the next 2000-batch average is unbiased.
                running_loss = 0.0
    print('done')

    5、在测试集上计算整体预测准确率

    # Overall accuracy of `net` on the whole test set.
    correct = 0
    total = 0
    # Inference only: disable gradient tracking to save memory and time.
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            # The index of the largest score in each row is the predicted class.
            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            # .item() converts the 0-d count tensor to a Python int.
            correct += (pred == labels).sum().item()
    print('average Accuracy: %d %%' % (100 * correct / total))

    6、展示各类别预测结果

    # Per-class counts of correct predictions and of test samples (10 CIFAR-10 classes).
    class_correct = [0.0] * 10
    class_total = [0.0] * 10
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            _, pred = torch.max(outputs, 1)
            # Element-wise comparison.  No squeeze(): a size-1 batch would otherwise
            # collapse to a 0-d tensor that cannot be indexed.
            c = pred == labels
            # Iterate over the real batch size rather than a hard-coded 4, so a short
            # final batch or a different batch_size neither skips samples nor overruns.
            for i in range(labels.size(0)):
                label = labels[i].item()
                class_correct[label] += float(c[i])
                class_total[label] += 1
    print('each class accuracy: ')

    7、实验结果

    
    

     

     8、完整代码

    import torch
    import torchvision
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    from torch.autograd import Variable
    import torchvision.transforms as transforms

    import matplotlib.pyplot as plt
    import numpy as np

    def imshow(img):
        """Draw an image tensor with matplotlib after undoing the 0.5/0.5 normalization.

        img is assumed to be channel-first (C, H, W), as the axis transpose below
        implies.  Pixel values are mapped back via x / 2 + 0.5, and the image is
        drawn on the current axes.  plt.show() is not called here.
        """
        restored = img / 2 + 0.5
        # matplotlib expects channel-last (H, W, C) arrays.
        channel_last = np.transpose(restored.numpy(), (1, 2, 0))
        plt.imshow(channel_last)

    # ToTensor scales pixel values to [0, 1]; Normalize then subtracts a per-channel
    # mean of 0.5 and divides by a per-channel std of 0.5, mapping each RGB channel
    # to [-1, 1].  Normalize's signature is (mean, std, inplace), so mean and std are
    # passed as explicit 3-tuples rather than as three positional scalars (a third
    # positional value would be taken as the `inplace` flag).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Training data: downloaded to ./data if missing, shuffled each epoch, 4 images per batch.
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)

    # Test data: same preprocessing.  Order does not affect accuracy, so it is not
    # shuffled, which keeps evaluation runs reproducible.
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)

    # Human-readable names for the 10 label indices, in label order.
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    class Net(nn.Module):
        """Small LeNet-5-style CNN for 3-channel 32x32 images (e.g. CIFAR-10).

        Despite the article title this is not AlexNet.  The layers are:
        - two 5x5 convolutions (3 -> 6 -> 16 channels), each followed by ReLU and
          2x2 max-pooling;
        - three fully connected layers (400 -> 120 -> 84 -> 10).
        """

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            # One pooling module is reused after both convolutions; it has no parameters.
            self.pool1 = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            # 32x32 input -> conv 28x28 -> pool 14x14 -> conv 10x10 -> pool 5x5,
            # so the flattened feature vector has 16 * 5 * 5 = 400 entries.
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            """Return unnormalized class scores of shape (batch, 10).

            x is assumed to be (batch, 3, 32, 32); other spatial sizes will not
            match fc1's 400 input features.
            """
            out = self.pool1(F.relu(self.conv1(x)))
            out = self.pool1(F.relu(self.conv2(out)))
            # flatten(1) keeps the batch dimension intact.  view(-1, 400) would instead
            # silently re-split the batch if the feature map were not 16x5x5.
            out = out.flatten(1)
            out = F.relu(self.fc1(out))
            out = F.relu(self.fc2(out))
            # No activation on the last layer: raw scores are returned.
            out = self.fc3(out)
            return out


    net = Net()
    print(net)

    # Loss: nn.CrossEntropyLoss takes unnormalized class scores plus integer labels.
    cost = nn.CrossEntropyLoss()

    # Optimizer: SGD with momentum over all parameters of `net`.
    optimizer = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)


    # Train for 2 epochs, printing the mean loss of every 2000 mini-batches.
    for epoch in range(2):
        running_loss = 0.0
        # Tensors take part in autograd directly; the deprecated Variable wrapper is unnecessary.
        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = cost(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                # Reset to exactly zero so the next 2000-batch average is unbiased.
                running_loss = 0.0
    print('done')

    # Overall accuracy of `net` on the whole test set.
    correct = 0
    total = 0
    # Inference only: disable gradient tracking to save memory and time.
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            # The index of the largest score in each row is the predicted class.
            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            # .item() converts the 0-d count tensor to a Python int.
            correct += (pred == labels).sum().item()
    print('average Accuracy: %d %%' % (100 * correct / total))

    # Per-class counts of correct predictions and of test samples (10 CIFAR-10 classes).
    class_correct = [0.0] * 10
    class_total = [0.0] * 10
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            _, pred = torch.max(outputs, 1)
            # Element-wise comparison.  No squeeze(): a size-1 batch would otherwise
            # collapse to a 0-d tensor that cannot be indexed.
            c = pred == labels
            # Iterate over the real batch size rather than a hard-coded 4, so a short
            # final batch or a different batch_size neither skips samples nor overruns.
            for i in range(labels.size(0)):
                label = labels[i].item()
                class_correct[label] += float(c[i])
                class_total[label] += 1
    print('each class accuracy: ')

    for i, name in enumerate(classes):
        # Guard against division by zero if a class never appeared in the test data.
        if class_total[i] == 0:
            print('Accuracy: %6s  n/a' % name)
            continue
        print('Accuracy: %6s %2d %%' % (name, 100 * class_correct[i] / class_total[i]))





    
    
    
    
  • 相关阅读:
    python数据类型和数据运算
    python 模块介绍
    Dictionary<Key,Value>的用法
    不用Invoke就等用 Control.CheckForIllegalCrossThreadCalls = false;
    多线程+委托的安全访问(invoke)
    Lambda 表达式型的排序法
    System.Windows.Forms.AxHost.InvalidActiveXStateException”类型的异常在 ESRI.ArcGIS.AxControls.dll 中发生,但未在用户代码中进行处理
    无法嵌入互操作类型“ESRI.ArcGIS.Display.SimpleFillSymbolClass”。请改用适用的接口。
    JavaScript中样式,方法 函数的应用
    Arcgis Engine最短路径分析
  • 原文地址:https://www.cnblogs.com/xufeng123/p/13744861.html
Copyright © 2011-2022 走看看