数据集下载地址:
链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
提取码:2xq4
之前在:https://www.cnblogs.com/xiximayou/p/12398285.html创建好了数据集,将它上传到谷歌colab
在colab上的目录如下:
在utils中的rdata.py定义了读取该数据集的代码:
from torch.utils.data import DataLoader import torchvision import torchvision.transforms as transforms import torch #预处理 transform = transforms.Compose([transforms.ToTensor()]) path = "/content/drive/My Drive/colab notebooks/data/dogcat" train_path=path+"/train" test_path=path+"/test" #使用torchvision.datasets.ImageFolder读取数据集指定train和test文件夹 train_data = torchvision.datasets.ImageFolder(train_path, transform=transform) train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=1) test_data = torchvision.datasets.ImageFolder(test_path, transform=transform) test_loader = DataLoader(test_data, batch_size=32, shuffle=True, num_workers=1) print(train_data.classes) #根据分的文件夹的名字来确定的类别 print(train_data.class_to_idx) #按顺序为这些类别定义索引为0,1... print(train_data.imgs) #返回从所有文件夹中得到的图片的路径以及其类别 print(test_data.classes) #根据分的文件夹的名字来确定的类别 print(test_data.class_to_idx) #按顺序为这些类别定义索引为0,1... print(test_data.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
ImageFolder可以读取我们的train或test下面的文件夹,并为每一个标签进行编码,同时将图片与标签进行对应。
在test.ipynb中运行rdata.py
说明我们创建的数据集是可以用的了。
有了数据集,接下来就是网络的搭建以及训练和测试了。