zoukankan      html  css  js  c++  java
  • 图像分类器

    1.使用torchvision加载并标准化CIFAR10训练集和测试集
    2.定义卷积神经网络
    3.定义损失函数和优化器
    4.用训练数据训练网络
    5.在测试数据上测试网络

    1.下载并标准化数据集

    import torch
    import torchvision
    import torchvision.transforms as transforms
    # Normalize computes (x - mean) / std per RGB channel; with mean = std = 0.5
    # the [0, 1] values produced by ToTensor are mapped onto [-1, 1].
    _CHANNEL_STATS = (0.5, 0.5, 0.5)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_STATS, _CHANNEL_STATS),
    ])


    def _make_cifar10(train, shuffle):
        """Return ``(dataset, loader)`` for one CIFAR10 split.

        The dataset lives under ``./data`` (download=True fetches it if it is
        missing). Both splits share ``transform``, use mini-batches of 4 and
        load with two worker processes.
        """
        dataset = torchvision.datasets.CIFAR10(
            root='./data', train=train, download=True, transform=transform)
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=4, shuffle=shuffle, num_workers=2)
        return dataset, loader


    # Training batches are reshuffled every epoch; test batches keep a fixed order.
    trainset, trainloader = _make_cifar10(train=True, shuffle=True)
    testset, testloader = _make_cifar10(train=False, shuffle=False)

    # Human-readable names, indexed by the integer CIFAR10 label.
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    (base) [root@pyspark data]# python cifar-10-python.py
    Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
    100%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉| 170483712/170498071 [5:29:18<00:00, 36983.49it/s]Extracting ./data/cifar-10-python.tar.gz to ./data
    Files already downloaded and verified
    170500096it [5:29:28, 8624.70it/s]
    (base) [root@pyspark data]# ls
    cifar-10-python.py data
    (base) [root@pyspark data]# cd data/
    (base) [root@pyspark data]# ls
    cifar-10-batches-py cifar-10-python.tar.gz


    2.检查训练数据
    import matplotlib.pyplot as plt
    import numpy as np
    
    # functions to show an image
    
    
    def imshow(img):
        """Display an image tensor that was normalized to [-1, 1].

        ``img`` is channel-first (C, H, W), as implied by the transpose below.
        It is mapped back to [0, 1] and reordered to (H, W, C) for matplotlib.
        """
        restored = img / 2 + 0.5  # inverse of Normalize(mean=0.5, std=0.5)
        channels_last = np.transpose(restored.numpy(), (1, 2, 0))
        plt.imshow(channels_last)
        plt.show()
    
    
    # Pull one shuffled mini-batch from the training loader.
    # DataLoader iterators only implement __next__; the legacy `.next()` alias
    # was removed in newer PyTorch releases, so use the builtin next().
    dataiter = iter(trainloader)
    images, labels = next(dataiter)

    # Tile the whole batch into one grid image and display it.
    imshow(torchvision.utils.make_grid(images))
    # Print one class name per image in the batch. Iterating over `labels`
    # (instead of a hard-coded range(4)) keeps this correct if batch_size changes.
    print(' '.join('%5s' % classes[label] for label in labels))
    3.定义卷积神经网络
    import torch.nn as nn
    import torch.nn.functional as F
    
    class Net(nn.Module):
        """Small LeNet-style CNN mapping 3-channel 32x32 images to 10 class scores.

        Layout: conv(3->6, 5x5) -> ReLU -> 2x2 max-pool -> conv(6->16, 5x5)
        -> ReLU -> 2x2 max-pool -> FC 400->120 -> ReLU -> FC 120->84 -> ReLU
        -> FC 84->10. The 16 * 5 * 5 input width of ``fc1`` only matches
        32x32 inputs (spatial size 32 -> 28 -> 14 -> 10 -> 5).
        """

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            """Return raw (un-softmaxed) scores with 10 entries per image.

            Raises:
                RuntimeError: if the image size does not yield 16x5x5 conv
                    features (i.e. the input is not 32x32).
            """
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            # view(-1, 400) on its own would silently fold extra spatial
            # positions into the batch axis for some wrong image sizes (e.g. a
            # single 52x52 image becomes 4 "samples"). Reject such inputs
            # explicitly. RuntimeError matches torch's own shape-mismatch errors.
            if tuple(x.shape[-3:]) != (16, 5, 5):
                raise RuntimeError(
                    'Net expects 3x32x32 images; got conv features of shape '
                    f'{tuple(x.shape[-3:])}, expected (16, 5, 5)')
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)


    net = Net()
    定义损失函数及优化器
    import torch.optim as optim
    
    # Loss applied directly to the raw scores returned by net(...).
    criterion = nn.CrossEntropyLoss()
    # SGD with momentum over every learnable parameter of `net`.
    optimizer = optim.SGD(params=net.parameters(), momentum=0.9, lr=0.001)
    训练神经网络
    _LOG_EVERY = 2000  # mini-batches between loss reports

    for epoch in range(2):  # two full passes over the training set
        running_loss = 0.0
        for step, (inputs, labels) in enumerate(trainloader):
            # Gradients accumulate across backward() calls, so clear them first.
            optimizer.zero_grad()

            outputs = net(inputs)               # forward pass
            loss = criterion(outputs, labels)
            loss.backward()                     # backpropagate
            optimizer.step()                    # apply the parameter update

            # Report the mean loss over the last _LOG_EVERY mini-batches.
            running_loss += loss.item()
            if (step + 1) % _LOG_EVERY == 0:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, step + 1, running_loss / _LOG_EVERY))
                running_loss = 0.0

    print('Finished Training')
    保存训练好的神经网络
    PATH = './cifar_net.pth'
    # Persist the module's state_dict (its parameters and buffers), not the module object.
    weights = net.state_dict()
    torch.save(weights, PATH)
    在测试数据集上测试神经网络
    # Grab the first test mini-batch (testloader does not shuffle) and show it.
    # Use the builtin next(): the `.next()` method on DataLoader iterators was
    # removed in newer PyTorch releases and raises AttributeError.
    dataiter = iter(testloader)
    images, labels = next(dataiter)

    # print images
    imshow(torchvision.utils.make_grid(images))
    # One ground-truth class name per image, for whatever the batch size is.
    print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels))
     
  • 相关阅读:
    在IE浏览器中url传参长度问题
    Linq语句的认识
    关于选择表达式以及判断语句的书写,可以让代码更加的清晰。
    C#/对线程的认识
    Js/如何修改easyui修饰的input的val值
    Java Lambda表达式中的this
    MySQL USING关键词/USING()函数的使用
    复杂SQL查询
    Java 修饰符
    Git:idea中将当前分支修改的内容提交到其他分支上
  • 原文地址:https://www.cnblogs.com/songyuejie/p/12003468.html
Copyright © 2011-2022 走看看