zoukankan      html  css  js  c++  java
  • LeNet 网络进行猫狗大战

    最近给学生布置了猫狗大战的作业,是我自己拍脑袋想的。我发现大多同学做的并太理想,主要原因是因为对pytorch不太熟悉。中秋假期我也做了这个作业,效果虽然并不算好,但可以做为一个范例提供给初学者学习。(其实我写的网络和大家差不多,并不是 LeNet,是一个只有卷积、池化、全连接的简单CNN)

    大家普遍反映有两个问题:1、网络不收敛;2、Colab上训练时间太长。

    问题1解决方法: 网络不收敛,可以使用更深的网络。网络变深时,sigmoid 激活函数容易进入饱和区,就不收敛了,要把激活函数替换为 ReLU 。另外,大家可以把优化器由 SGD 替换为 Adam ,一般情况下,我认为 Adam 效果会比 SGD 好一些。具体原因大家可以自己补课,这里不多说。

    问题2解决方法: 训练时间长还是因为数据量大,这里我采取策略是选择较少的训练样本,猫狗各取了2000个,训练时间会大大缩短。(总体思路还是先保证代码能够跑起来,实际需要的话,再放到服务器上跑)。

    第1步:加载数据集,导入基本库

    # 这个是训练集,猫狗各取了2000个
    ! wget https://gaopursuit.oss-cn-beijing.aliyuncs.com/2021/files/train.zip
    ! unzip train.zip
    # 这个是测试集
    ! wget https://gaopursuit.oss-cn-beijing.aliyuncs.com/202007/dogs_cats_test.zip
    ! unzip dogs_cats_test.zip
    import numpy as np
    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torchvision
    from torchvision import models,transforms,datasets
    import torch.nn.functional as F
    from PIL import Image
    import torch.optim as optim
    import json, random
    import os
    
    # 判断是否存在GPU设备
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('Using gpu: %s ' % torch.cuda.is_available())
    # 训练图片和测试图片的路径
    train_path = './train/'
    test_path = './test/'
    

    第2步:创建数据集,这里采用了 齐昊 的代码。

    def get_data(file_path):
        file_lst = os.listdir(file_path) #获得所有文件名称 xxxx.jpg
        data_lst = []
        for i in range(len(file_lst)):
            clas = file_lst[i][:3] #cat和dog在文件名的开头
            img_path = os.path.join(file_path,file_lst[i])#将文件名与路径合并得到完整路径,以备读取
            if clas == 'cat':
                data_lst.append((img_path, 0))
            else:
                data_lst.append((img_path, 1))
        return data_lst
    class catdog_set(torch.utils.data.Dataset):
        def __init__(self, path, transform):
            super(catdog_set).__init__()
            self.data_lst = get_data(path)#调用刚才的函数获得数据列表
            self.trans = torchvision.transforms.Compose(transform)
        def __len__(self):
            return len(self.data_lst)
        def __getitem__(self,index):
            (img,cls) = self.data_lst[index]
            image = self.trans(Image.open(img))
            label = torch.tensor(cls,dtype=torch.float32)
            return image,label
    # 将输入图像缩放为 128*128,每一个 batch 中图像数量为128
    # 训练时,每一个 epoch 随机打乱图像的顺序,以实现样本多样化
    train_loader = torch.utils.data.DataLoader(
        catdog_set(train_path, [transforms.Resize((128,128)),transforms.ToTensor()]), 
        batch_size=128, shuffle=True)
    

    第3步:定义网络

    这里是网络的定义,下面附有一段测试代码,因为输入图像是 3x128x128,可以看到网络的处理为:

    3x128x128 ==> conv1( 6x124x124 ==> 6x62x62 )

    ==> conv2( 16x58x58 ==> 16x29x29 )

    ==> conv3( 32x26x26 ==> 32x13x13 )

    ==> conv4( 32x10x10 ==> 32x5x5 )

    ==> conv5( 32x1x1 ==> 32)

    ==> 32 ==> 16 ==> 2

    可以在 forward 函数中加一些 print 函数测试,观察 feature map 形状的变化

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.conv3 = nn.Conv2d(16, 32, 4)
            self.conv4 = nn.Conv2d(32, 32, 4)
            self.conv5 = nn.Conv2d(32, 32, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(32, 16)
            self.fc2 = nn.Linear(16, 2)
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = self.pool(F.relu(self.conv3(x)))
            x = self.pool(F.relu(self.conv4(x)))
            x = F.relu(self.conv5(x))
            x = x.view(-1, 32)
            x = F.relu(self.fc1(x))
            x = F.softmax(self.fc2(x), dim=1)
            return x
    
    # 随机输入,测试网络结构是否通
    # x = torch.randn(1, 3, 128, 128)
    # net = Net()
    # y = net(x)
    # print(y.shape)
    

    第4步:网络训练

    网络训练准备,三个要素:1、将网络放到 GPU 上;2、定义损失函数;3、定义优化器

    # 网络放到GPU上
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    

    开始训练,这里是训练30个 epoch。训练过程中,首先梯度归零,然后正向传播 + 计算损失 + 反向传播 + 优化。所有的网络训练都是这个过程。

    for epoch in range(30):  # 重复多轮训练
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            # 优化器梯度归零
            optimizer.zero_grad()
            # 正向传播 + 反向传播 + 优化 
            outputs = net(inputs)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step() 
        print('Epoch: %d loss: %.6f' %(epoch + 1, loss.item()))
    print('Finished Training')
    

    Epoch: 1 loss: 0.693591

    Epoch: 2 loss: 0.691867

    Epoch: 3 loss: 0.682191

    Epoch: 4 loss: 0.657503

    Epoch: 5 loss: 0.705427

    Epoch: 6 loss: 0.660462

    Epoch: 7 loss: 0.634173

    Epoch: 8 loss: 0.673097

    Epoch: 9 loss: 0.593748

    Epoch: 10 loss: 0.525426

    Epoch: 11 loss: 0.550043

    Epoch: 12 loss: 0.583569

    Epoch: 13 loss: 0.669047

    Epoch: 14 loss: 0.532821

    Epoch: 15 loss: 0.591107

    Epoch: 16 loss: 0.512886

    Epoch: 17 loss: 0.551087

    Epoch: 18 loss: 0.574038

    Epoch: 19 loss: 0.604391

    Epoch: 20 loss: 0.552344

    Epoch: 21 loss: 0.493089

    Epoch: 22 loss: 0.442594

    Epoch: 23 loss: 0.584947

    Epoch: 24 loss: 0.496618

    Epoch: 25 loss: 0.462232

    Epoch: 26 loss: 0.384848

    Epoch: 27 loss: 0.506766

    Epoch: 28 loss: 0.488330

    Epoch: 29 loss: 0.462068

    Epoch: 30 loss: 0.462078

    Finished Training

    可以看出,如果继续训练更多 epoch,或者使用更多训练样本,网络还可以优化。时间有限,不过多花时间了。

    第5步:测试并输出结果

    测试集包括2000张图片,一张张图片读取,输入网络预测结果,最后将训练结果写入文件。

    resfile = open('res.csv', 'w')
    for i in range(0,2000): 
        img_PIL = Image.open('./test/'+str(i)+'.jpg')
        img_tensor = transforms.Compose([transforms.Resize((128,128)),transforms.ToTensor()])(img_PIL)
        img_tensor = img_tensor.reshape(-1, img_tensor.shape[0], img_tensor.shape[1], img_tensor.shape[2])
        img_tensor = img_tensor.to(device)
        out = net(img_tensor).cpu().detach().numpy()
        if out[0, 0] < out[0, 1]:
            resfile.write(str(i)+','+str(1)+'
    ')
        else:
            resfile.write(str(i)+','+str(0)+'
    ')
    resfile.close()
    

    第一次测试,结果是 68.15 。第二次把 batch size 由 64 修改为了 128,准确率上升到了 74.6,也许还有其他方法(调节学习率、网络初始化等),大家自己研究吧。

  • 相关阅读:
    北京之行
    csharp进界
    医院OA系统新思考
    茗洋博客
    monkey主要参数详解
    使用python判断Android自动化的渠道包是否全部打完
    手机连接mac电脑无法使用adb命令解决方法
    Python正则表达式指南
    Mac基本命令大全
    Mac之vim普通命令使用
  • 原文地址:https://www.cnblogs.com/gaopursuit/p/15429484.html
Copyright © 2011-2022 走看看