zoukankan      html  css  js  c++  java
  • pytorch实践:dog VS cat

    猫狗分类，练手级代码。与手写数字识别相比，主要修改的是输出全连接层：将输出由10（十个数字）改为2（猫、狗二分类）。另一处改动是数据集处理：torchvision没有内置猫狗数据集，因此图片需要自己读取和预处理（也可以用 torchvision.datasets.ImageFolder 直接加载下面这种目录结构）。

    图片用OpenCV读取并缩放到32×32，再转换为张量并归一化。

    数据集目录结构：

        data/
        ├── train/
        │   ├── Cat/
        │   └── Dog/
        └── test/
            ├── Cat/
            └── Dog/

    get_data.py

    import os
    import cv2
    import time
    from torchvision import transforms
    import torch
    # Image pipeline: ToTensor turns an HxWxC uint8 array into a CxHxW float
    # tensor in [0, 1]; Normalize then maps each of the 3 channels to [-1, 1]
    # via (x - 0.5) / 0.5.
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    DATA_PATH = './data/'  # root folder holding the train/ and test/ sub-folders
    PIC_SIZE = 32          # every image is resized to PIC_SIZE x PIC_SIZE
    
    
    def _load_folder(folder, label, on_file, remove_invalid=True):
        """Read and resize every image file in *folder*.

        Args:
            folder: directory path ending in '/' (it is concatenated with file names).
            label: class index stored alongside every image from this folder.
            on_file: zero-argument callback invoked once per file (progress reporting).
            remove_invalid: if True, delete files OpenCV cannot decode or resize;
                otherwise just skip them.

        Returns:
            List of ``[image, label]`` pairs, *image* being the array returned by
            ``cv2.resize`` at ``PIC_SIZE x PIC_SIZE``.
        """
        samples = []
        for name in os.listdir(folder):
            path = folder + name
            image = cv2.imread(path)
            resized = None
            # cv2.imread reports an undecodable file by returning None, not by raising.
            if image is not None:
                try:
                    resized = cv2.resize(image, (PIC_SIZE, PIC_SIZE))
                except cv2.error:
                    pass  # treated as an invalid image just below
            if resized is None:
                # Only genuine decode/resize failures get here; unlike a
                # `except BaseException`, Ctrl-C can no longer delete a good file.
                if remove_invalid:
                    os.remove(path)
            else:
                samples.append([resized, label])
            on_file()
        return samples


    def get_files(remove_invalid=True):
        """Load the cat/dog images under DATA_PATH as normalised tensors.

        Reads ``train/Cat``, ``train/Dog``, ``test/Cat`` and ``test/Dog`` below
        ``DATA_PATH``. Cats are labelled 0 and dogs 1. Every image is resized
        to ``PIC_SIZE x PIC_SIZE`` and passed through ``trans``.

        Args:
            remove_invalid: if True (the default, as in the original script),
                image files OpenCV cannot decode are deleted from disk; pass
                False to skip them without touching the dataset.

        Returns:
            ``(train_data, test_data)``, each a list of ``[tensor, label]``
            pairs with all cats first, then all dogs.
        """
        print('now,loading data.due to The amount of data is huge,you have to wait minutes')
        start_time = last_report = time.time()

        def report_progress():
            # Print the elapsed time at most once every 20 seconds.
            nonlocal last_report
            if time.time() - last_report > 20:
                last_report = time.time()
                print('Take %d seconds' % (time.time() - start_time))

        def load(subdir, label):
            raw = _load_folder(DATA_PATH + subdir, label, report_progress, remove_invalid)
            return [[trans(image), lbl] for image, lbl in raw]

        # Evaluated left to right: train/Cat, train/Dog, test/Cat, test/Dog.
        train_data = load('train/Cat/', 0) + load('train/Dog/', 1)
        test_data = load('test/Cat/', 0) + load('test/Dog/', 1)

        # The scraped original had literal line breaks inside this string (a
        # syntax error); the intended newlines are written as '\n' escapes.
        print('have loaded the data:\nThere are %d train_data\nThere are %d test_data'
              % (len(train_data), len(test_data)))
        print('-----------------------------------------------------------------------------')

        return train_data, test_data
    
    if __name__ == '__main__':
        # Build the (train_data, test_data) pair once and cache it on disk;
        # the training script later reads it back from "data.pyd".
        dataset = get_files()
        torch.save(dataset, "data.pyd")

    运行 get_data.py 会把处理好的数据集保存到 data.pyd。

    然后运行 dogVScat.py 进行训练和测试。

    dogVScat.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import numpy as np
    
    # Hyper-parameters for the SGD optimiser and the training loop.
    LR = 1e-2        # SGD learning rate
    MOM = 0.5        # SGD momentum
    EPOCHES = 100    # number of training epochs
    BATCHSIZE = 50   # samples per mini-batch
    
    class Net(nn.Module):
        """Small three-stage CNN producing log-probabilities for 2 classes.

        Each stage is a 3x3 convolution (no padding), 2x2 max-pooling and ReLU.
        The last feature map is flattened into a linear layer with 2 outputs
        (label 0 = cat, 1 = dog in the accompanying data loader).

        Args:
            img_size: side length of the square RGB input images. The default
                32 matches ``PIC_SIZE`` and yields the original 40-input ``fc``.

        Raises:
            ValueError: if *img_size* is too small to survive three conv/pool stages.
        """

        def __init__(self, img_size=32):
            super(Net, self).__init__()

            self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)
            self.conv2 = nn.Conv2d(10, 20, 3)
            self.conv3 = nn.Conv2d(20, 10, 3)

            self.mp = nn.MaxPool2d(2)
            # Previously hard-coded to 40, which only fits 32x32 inputs.
            self.fc = nn.Linear(self._flat_features(img_size), 2)

        @staticmethod
        def _flat_features(img_size):
            """Return the flattened size of conv3's pooled output for *img_size* inputs."""
            side = img_size
            for _ in range(3):
                # A 3x3 conv without padding removes 2 pixels; MaxPool2d(2) halves (floor).
                side = (side - 2) // 2
            if side < 1:
                raise ValueError('img_size=%d is too small for three conv/pool stages' % img_size)
            return 10 * side * side  # conv3 has 10 output channels

        def forward(self, x):
            """Map a batch ``(N, 3, img_size, img_size)`` to ``(N, 2)`` log-probabilities."""
            in_size = x.size(0)

            x = F.relu(self.mp(self.conv1(x)))
            x = F.relu(self.mp(self.conv2(x)))
            x = F.relu(self.mp(self.conv3(x)))

            x = x.view(in_size, -1)
            x = self.fc(x)

            return F.log_softmax(x, dim=1)
    
    def train(data=None, net=None, opt=None, batch_size=None):
        """Run one epoch of mini-batch training.

        Every argument defaults to the module-level global the script sets up
        (``train_data``, ``model``, ``optimizer``, ``BATCHSIZE``), so the
        original zero-argument call ``train()`` still works.

        Args:
            data: list of ``(image_tensor, label)`` pairs; all image tensors
                must share one shape so they can be stacked.
            net: module returning per-class log-probabilities (used with nll_loss).
            opt: optimizer over ``net``'s parameters.
            batch_size: samples per optimisation step. A final, smaller batch
                is also trained on. The original dropped it, so the same
                trailing samples were never seen.
        """
        if data is None:
            data = train_data
        if net is None:
            net = model
        if opt is None:
            opt = optimizer
        if batch_size is None:
            batch_size = BATCHSIZE

        net.train()
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            xbatch = torch.stack([x for x, _ in batch])  # list of tensors -> one batch tensor
            ybatch = torch.tensor([y for _, y in batch], dtype=torch.long)

            loss = F.nll_loss(net(xbatch), ybatch)

            opt.zero_grad()
            loss.backward()
            opt.step()
    
    def test(epoch, data=None, net=None, batch_size=None):
        """Measure classification accuracy, print it for *epoch* and return it.

        Every optional argument defaults to the module-level global the script
        sets up (``test_data``, ``model``, ``BATCHSIZE``), so ``test(epoch)``
        still works.

        Args:
            epoch: epoch number shown in the printed message.
            data: list of ``(image_tensor, label)`` pairs.
            net: module whose per-row argmax is the predicted class.
            batch_size: samples per forward pass. The final, smaller batch is
                scored too; the original skipped it yet still divided by
                ``len(test_data)``, which understated accuracy.

        Returns:
            Accuracy in percent (0.0 for an empty dataset).
        """
        if data is None:
            data = test_data
        if net is None:
            net = model
        if batch_size is None:
            batch_size = BATCHSIZE

        was_training = net.training
        net.eval()
        correct = 0
        with torch.no_grad():  # inference only: don't build an autograd graph
            for start in range(0, len(data), batch_size):
                batch = data[start:start + batch_size]
                xbatch = torch.stack([x for x, _ in batch])
                ybatch = torch.tensor([y for _, y in batch], dtype=torch.long)
                pred = net(xbatch).argmax(dim=1)
                correct += int((pred == ybatch).sum())
        net.train(was_training)  # leave the module's mode as we found it

        total = len(data)
        accuracy = 100.0 * correct / total if total else 0.0
        print('correct of epoch {} is {:.2f}%'.format(epoch, accuracy))
        return accuracy


    # A function literally named `test` is picked up by pytest when a test file
    # imports it; it is a training-script helper, not a unit test.
    test.__test__ = False
    
    if __name__ == '__main__':
        model = Net()
        optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOM)
        # data.pyd holds the (train_data, test_data) pair written by get_data.py.
        train_data, test_data = torch.load("data.pyd")
        for epoch in range(EPOCHES):
            # The saved training data is ordered all cats, then all dogs.
            # Reshuffle every epoch (not just once) so mini-batches differ
            # between epochs.
            np.random.shuffle(train_data)
            train()
            test(epoch)

    训练结果:

    correct of epoch 0 is 52.33%
    correct of epoch 1 is 54.84%
    correct of epoch 2 is 55.95%
    correct of epoch 3 is 56.59%
    correct of epoch 4 is 57.57%
    correct of epoch 5 is 60.50%
    correct of epoch 6 is 62.18%
    correct of epoch 7 is 63.81%
    correct of epoch 8 is 64.46%
    correct of epoch 9 is 65.24%
    correct of epoch 10 is 65.93%
    correct of epoch 11 is 66.55%
    correct of epoch 12 is 67.47%
    correct of epoch 13 is 68.45%
    correct of epoch 14 is 69.00%
    correct of epoch 15 is 69.62%
    correct of epoch 16 is 69.99%
    correct of epoch 17 is 70.58%
    correct of epoch 18 is 71.10%
    correct of epoch 19 is 71.42%
    correct of epoch 20 is 71.87%
    correct of epoch 21 is 72.31%
    correct of epoch 22 is 72.36%
    correct of epoch 23 is 72.76%
    correct of epoch 24 is 73.01%
    correct of epoch 25 is 73.32%
    correct of epoch 26 is 73.36%
    correct of epoch 27 is 73.51%
    correct of epoch 28 is 73.17%
    correct of epoch 29 is 73.38%
    correct of epoch 30 is 73.50%
    correct of epoch 31 is 73.73%
    correct of epoch 32 is 73.93%
    correct of epoch 33 is 74.15%
    correct of epoch 34 is 74.11%
    correct of epoch 35 is 74.22%
    correct of epoch 36 is 74.26%
    correct of epoch 37 is 74.07%
    correct of epoch 38 is 74.12%
    correct of epoch 39 is 74.35%
    correct of epoch 40 is 74.38%
    correct of epoch 41 is 74.44%
    correct of epoch 42 is 74.17%
    correct of epoch 43 is 74.19%
    correct of epoch 44 is 74.30%
    correct of epoch 45 is 74.61%
    correct of epoch 46 is 74.64%
    correct of epoch 47 is 74.54%
    correct of epoch 48 is 74.58%
    correct of epoch 49 is 74.59%
    correct of epoch 50 is 74.59%
    correct of epoch 51 is 74.53%
    correct of epoch 52 is 74.45%
    correct of epoch 53 is 74.43%
    correct of epoch 54 is 74.43%
    correct of epoch 55 is 74.41%
    correct of epoch 56 is 74.42%
    correct of epoch 57 is 74.52%
    correct of epoch 58 is 74.48%
    correct of epoch 59 is 74.34%
    correct of epoch 60 is 74.21%
    correct of epoch 61 is 74.16%
    correct of epoch 62 is 74.15%
    correct of epoch 63 is 74.25%
    correct of epoch 64 is 74.11%
    correct of epoch 65 is 73.95%
    correct of epoch 66 is 73.85%
    correct of epoch 67 is 73.99%
    correct of epoch 68 is 74.15%
    correct of epoch 69 is 74.05%
    correct of epoch 70 is 74.05%
    correct of epoch 71 is 74.34%
    correct of epoch 72 is 74.21%
    correct of epoch 73 is 74.14%
    correct of epoch 74 is 73.98%
    correct of epoch 75 is 73.87%
    correct of epoch 76 is 73.88%
    correct of epoch 77 is 73.85%
    correct of epoch 78 is 73.84%
    correct of epoch 79 is 73.84%
    correct of epoch 80 is 73.65%
    correct of epoch 81 is 73.66%
    correct of epoch 82 is 73.43%
    correct of epoch 83 is 73.36%
    correct of epoch 84 is 73.30%
    correct of epoch 85 is 73.12%
    correct of epoch 86 is 73.20%
    correct of epoch 87 is 73.22%
    correct of epoch 88 is 73.13%
    correct of epoch 89 is 73.16%
    correct of epoch 90 is 73.17%
    correct of epoch 91 is 72.99%
    correct of epoch 92 is 73.09%
    correct of epoch 93 is 73.02%
    correct of epoch 94 is 72.80%
    correct of epoch 95 is 72.98%
    correct of epoch 96 is 72.73%
    correct of epoch 97 is 72.80%
    correct of epoch 98 is 72.76%
    correct of epoch 99 is 72.68%

    最高准确率为74.64%（第46个epoch），之后准确率逐渐下降到约72.7%，可能已经出现过拟合。

  • 相关阅读:
    Python split()方法分割字符串
    Python创建线程
    Python find()方法
    webpack中‘vant’全局引入和按需引入【vue-cli】
    webpack中‘mint-ui’全局引入和按需引入【vue-cli】
    nginx中 处理post方式打开页面的报错405
    nginx中 vue路由去掉#后的配置问题
    webpack中 VUE使用搜狐ip库查询设备ip地址
    webpack中 VUE使用百度地图获取地理位置
    VUE动态设置网页head中的title
  • 原文地址:https://www.cnblogs.com/jiangnanyanyuchen/p/9782283.html
Copyright © 2011-2022 走看看