zoukankan      html  css  js  c++  java
  • 【猫狗数据集】pytorch训练猫狗数据集之创建数据集

    数据集下载地址:

    链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
    提取码:2xq4

    猫狗数据集的分为训练集25000张,在训练集中猫和狗的图像是混在一起的,pytorch读取数据集有两种方式,第一种方式是将不同类别的图片放于其对应的类文件夹中,另一种是实现读取数据集类,该类继承torch.utils.Dataset,并重写__getitem__和__len__。

    先将猫和狗从训练集中区分开来,分别放到dog和cat文件夹下:

    import glob
    import shutil
    import os
    
    #数据集目录
    path = "./ml/dogs-vs-cats/train"
    #训练集目录
    train_path = path+'/train'
    #测试集目录
    test_path = path+'/test'
    
    #将某类图片移动到该类的文件夹下
    def img_to_file(path):
        print("=========开始移动图片============")
        #如果没有dog类和cat类文件夹,则新建
        if not os.path.exists(path+"/dog"):
                os.makedirs(path+"/dog")
        if not os.path.exists(path+"/cat"):
                os.makedirs(path+"/cat")
        print("共:{}张图片".format(len(glob.glob(path+"/*.jpg"))))
        #通过glob遍历到所有的.jpg文件
        for imgPath in glob.glob(path+"/*.jpg"):
            #print(imgPath)
            #使用/划分
            img=imgPath.strip("
    ").replace("\","/").split("/")
            #print(img)
            #将图片移动到指定的文件夹中
            if img[-1].split(".")[0] == "cat":
                shutil.move(imgPath,path+"/cat")
            if img[-1].split(".")[0] == "dog":
                shutil.move(imgPath,path+"/dog")
        print("=========移动图片完成============")    
    img_to_file(train_path)
    print("训练集猫共:{}张图片".format(len(glob.glob(train_path+"/cat/*.jpg"))))
    print("训练集狗共:{}张图片".format(len(glob.glob(train_path+"/dog/*.jpg"))))

    然后从dog中和cat中分别抽取1250张,共2500张图片作为测试集。

    import random
    
    def split_train_test(fileDir,tarDir):
    
            if not os.path.exists(tarDir):
                os.makedirs(tarDir)
            pathDir = os.listdir(fileDir)    #取图片的原始路径
            filenumber=len(pathDir)
            rate=0.1    #自定义抽取图片的比例,比方说100张抽10张,那就是0.1
            picknumber=int(filenumber*rate) #按照rate比例从文件夹中取一定数量图片
            sample = random.sample(pathDir, picknumber)  #随机选取picknumber数量的样本图片
            print("=========开始移动图片============")
            for name in sample:
                    shutil.move(fileDir+name, tarDir+name)
            print("=========移动图片完成============")
    split_train_test(train_path+'/dog/',test_path+'/dog/')  
    split_train_test(train_path+'/cat/',test_path+'/cat/')  

    最终,我们就有以下结构了:

    其中train包含22500张图片,其中dog类和cat类各11250张。test包含2500张图片,其中dog类和cat类各1250张。

    发现测试集还是有点少,那就再来一遍了。

    最后,train包含20250张图片,其中dog类和cat类各10125张。test包含4750张图片,其中dog类和cat类各2375张。

  • 相关阅读:
    看书笔记《python基础》__1
    MQTT
    杂记
    类型转化
    soc
    时钟同步
    设置地址
    清理日志
    pandas_matplot_seaborn
    Qt_Quick开发实战精解_4
  • 原文地址:https://www.cnblogs.com/xiximayou/p/12398285.html
Copyright © 2011-2022 走看看