PyTorch中的数据
Dataset、DataLoader 与 transforms
数据集的格式
为分类数据按类别文件夹生成标签
制作训练和验证数据的.txt文件
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import os
def list_dir(path):
    """Collect the files of every category sub-directory under ``path``.

    Each immediate sub-directory of ``path`` is treated as one category. The
    files inside it are returned as paths relative to ``path`` in the form
    ``"/<category>/<file>"``. The leading "/" lets the dataset later build the
    full path with plain string concatenation: ``root_dir + relative_path``.

    Args:
        path: Root directory whose sub-directories are the categories.

    Returns:
        dict mapping each category name to a list of ``"/<category>/<file>"``
        strings. Categories and files are sorted, so the generated label file
        is identical across runs and file systems.
    """
    res = {}
    for category in sorted(os.listdir(path)):
        category_dir = os.path.join(path, category)
        if not os.path.isdir(category_dir):
            # Loose files in the root (e.g. the generated .txt list) are not categories.
            continue
        res[category] = [
            # Always use "/" so the list file is portable; os.path.join would
            # emit "\\" on Windows.
            "/" + category + "/" + name
            for name in sorted(os.listdir(category_dir))
            # Only regular files are samples; nested directories are skipped.
            if os.path.isfile(os.path.join(category_dir, name))
        ]
    return res
def get_text(path, fil_dict, relation=None):
    """Write a ``<relative path> <label>`` list file for a class-per-folder dataset.

    The file is written into ``path`` and named after the last component of
    ``path``. For example, ``./data/train`` produces ``./data/train/train.txt``.
    Each line has the form ``"/<category>/<file> <label>"`` and is also
    echoed to stdout.

    Args:
        path: Dataset root directory; the list file is written here.
        fil_dict: Mapping of category name -> list of relative file paths,
            as produced by ``list_dir``.
        relation: Optional mapping of category name -> integer label.
            Defaults to ``{"dog": 1, "cat": 2}``.

    Raises:
        KeyError: if a file's category has no entry in ``relation``.
    """
    if relation is None:
        # NOTE(review): the default labels start at 1; losses such as
        # nn.CrossEntropyLoss expect class indices starting at 0 — confirm.
        relation = {"dog": 1, "cat": 2}
    # normpath drops a trailing separator, so "./data/train/" still yields
    # "train.txt" instead of ".txt".
    file_nm = os.path.split(os.path.normpath(path))[-1] + ".txt"
    with open(os.path.join(path, file_nm), mode="w", encoding="utf-8") as f:
        for category_key, label_files in fil_dict.items():
            for label_file in label_files:
                line = label_file + " " + str(relation[category_key])
                print(line)
                f.write(line + "\n")
if __name__ == '__main__':
    # Write "<train_root>/train.txt" listing every file under the training root.
    train_root = "./pytorch/data/train"
    get_text(train_root, list_dir(train_root))
自定义Dataset
自定义Dataset：继承 torch.utils.data.Dataset，并实现 __init__、__len__() 和 __getitem__() 方法
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
# Step 1: define MyDataset, subclassing Dataset and implementing __init__, __len__ and __getitem__
class MyDataset(Dataset):
    """Image classification dataset driven by a ``<path> <label>`` list file.

    Each non-empty line of ``names_file`` describes one sample as
    ``"<relative path> <integer label>"``. The relative path is appended to
    ``root_dir`` by plain string concatenation, so it is expected to start
    with a separator (e.g. ``"/cat/001.jpg"``).

    Args:
        root_dir: Directory that the relative paths in ``names_file`` are based on.
        names_file: Path of the list file (read as UTF-8).
        transform: Optional callable applied to each
            ``{'image': ..., 'label': ...}`` sample dict.

    Raises:
        FileNotFoundError: if ``names_file`` does not exist.
    """

    def __init__(self, root_dir, names_file, transform=None):
        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform
        if not os.path.isfile(self.names_file):
            # Fail fast with a clear message; previously this only printed
            # and then crashed inside open().
            raise FileNotFoundError(self.names_file + ' ## does not exist!')
        # `with` guarantees the handle is closed; utf-8 matches the writer.
        with open(self.names_file, encoding="utf-8") as file:
            # Drop trailing newlines and skip blank lines so that every entry
            # is a real sample and __len__ is accurate.
            self.names_list = [line.strip() for line in file if line.strip()]
        self.size = len(self.names_list)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'image': ndarray, 'label': int}``, then transform it.

        Raises:
            FileNotFoundError: if the image file does not exist.
            OSError: if OpenCV cannot decode the image.
        """
        # Split on the last space only, so paths containing spaces survive.
        rel_path, label = self.names_list[idx].rsplit(' ', 1)
        image_path = self.root_dir + rel_path
        if not os.path.isfile(image_path):
            # Returning None here would make DataLoader's collate step fail
            # with an obscure error; raise a descriptive one instead.
            raise FileNotFoundError(image_path + '@does not exist!')
        # cv2.imread signals failure by returning None rather than raising.
        # NOTE(review): OpenCV yields BGR channel order; convert if RGB is expected.
        image = cv2.imread(image_path)
        if image is None:
            raise OSError('failed to decode image: ' + image_path)
        sample = {'image': image, 'label': int(label)}
        if self.transform:
            sample = self.transform(sample)
        return sample
# Transform: Resize
class Resize(object):
    """Sample transform that rescales ``sample['image']`` with ``cv2.resize``.

    Only the ``'image'`` and ``'label'`` entries are carried into the new
    sample dict; the input dict itself is not modified.

    Args:
        output_size: Passed unchanged as the ``dsize`` argument of
            ``cv2.resize``. NOTE(review): OpenCV reads ``dsize`` as
            ``(width, height)``, not ``(height, width)``.
    """

    def __init__(self, output_size: tuple):
        self.output_size = output_size

    def __call__(self, sample):
        resized = cv2.resize(sample['image'], self.output_size)
        return dict(image=resized, label=sample['label'])
# Transform: ToTensor
class ToTensor(object):
    """Sample transform turning an ``H x W x C`` numpy image into a ``C x H x W`` tensor.

    The pixel dtype is preserved (``uint8`` stays ``uint8``). No scaling to
    ``[0, 1]`` is done, unlike ``torchvision.transforms.ToTensor``.
    Two-dimensional (grayscale, ``H x W``) images get a single channel axis,
    producing ``1 x H x W`` instead of raising in ``np.transpose``.
    The label is passed through unchanged.
    """

    def __call__(self, sample):
        image = sample['image']
        if image.ndim == 2:
            # Grayscale images (e.g. cv2.IMREAD_GRAYSCALE) have no channel axis.
            image = np.expand_dims(image, axis=2)
        # HWC -> CHW, the layout expected by torch conv layers.
        image_new = np.transpose(image, (2, 0, 1))
        return {'image': torch.from_numpy(image_new),
                'label': sample['label']}
if __name__ == "__main__":
    # Resize every image to 224x224, then convert it to a CHW tensor.
    pipeline = transforms.Compose([
        Resize((224, 224)),
        ToTensor(),
    ])
    train_dataset = MyDataset(
        root_dir='./pytorch/data/train',
        names_file='./pytorch/data/train/train.txt',
        transform=pipeline,
    )

    # Sanity pass: load every sample once and print its label.
    for sample in train_dataset:
        print(sample['label'])

    trainset_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
    )
    # Inspect the collated batches produced by the DataLoader.
    for batch in trainset_dataloader:
        images_batch = batch['image']
        labels_batch = batch['label']
        print(labels_batch.shape, labels_batch.dtype)
        print(images_batch.shape, images_batch.dtype)
        print(labels_batch)
        print(images_batch)
参考
https://pytorch.org/docs/stable/data.html