1.将图片的路径和标签写入csv文件并实现读取
1 # 创建一个文件,包含image,存放方式:label pokemeon\mew\0001.jpg,0 2 def load_csv(self,filename): 3 if not os.path.exists(os.path.join(self.root,filename)): 4 images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断 5 for name in self.name2label.keys(): 6 # pokemeon\mew\0001.jpg mew可以通过字典查看其类别 7 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径 8 images += glob.glob(os.path.join(self.root,name,'*.jpg')) 9 random.shuffle(images) 10 with open(os.path.join(self.root,filename),'w') as f: 11 writer = csv.writer(f) 12 for img in images: 13 name = img.split(os.sep) 14 label = self.name2label[name[-2]] 15 writer.writerow([img,label]) 16 17 # 从csv中读取文件 18 images, labels = [], [] 19 with open(os.path.join(self.root,filename),'r') as f: 20 reader = csv.reader(f) 21 for row in reader: 22 img,label = row 23 label = int(label) 24 images.append(img) 25 labels.append(label) 26 assert len(images) == len(labels) # 保证数据长度一致
return images,labels
2.加载自定义数据集
1 """ 2 自定义数据集 3 image_resize 4 data argumentation(数据增强):Rotate,crop 5 normalize:mean,std 6 ToTensor 7 8 """ 9 import torch 10 import os,glob 11 import random,csv 12 from torch.utils.data import Dataset,DataLoader 13 from torchvision import transforms 14 from PIL import Image 15 import visdom 16 17 18 class Pokemon(Dataset): 19 def __init__(self,root,resize,mode): 20 super(Pokemon,self).__init__() 21 self.root = root 22 self.resize = resize 23 self.name2label = {} 24 for name in os.listdir(os.path.join(root)): #把文件和dir都会加载近来 25 if not sorted(os.path.isdir(os.path.join(root,name))):#排序后,文件夹顺序固定了 26 continue 27 self.name2label[name] = len(self.name2label.keys()) 28 # name2label:{文件夹名,类别编号} 29 # 创建一个文件,包含image,存放方式:label pokemeon\mew\0001.jpg,0 30 self.images, self.labels = self.load_csv('images.csv') 31 # 对数据进行裁剪,mode:train-0.6,validation-0.2,test-0.2数据量是不同的 32 if mode == 'train': 33 self.images = self.images[:,int(len(self.images)*0.6)] 34 self.labels = self.labels[:,int(len(self.images)*0.6)] 35 elif mode == 'val': 36 self.images = self.images[int(len(self.images)*0.6):int(len(self.images)*0.8)] 37 self.labels = self.labels[int(len(self.labels)*0.6):int(len(self.labels)*0.8)] 38 else: 39 self.images = self.images[int(len(self.images) * 0.8):] 40 self.labels = self.labels[int(len(self.labels) * 0.8):] 41 42 def load_csv(self,filename): 43 if not os.path.exists(os.path.join(self.root,filename)): 44 images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断 45 for name in self.name2label.keys(): 46 # pokemeon\mew\0001.jpg mew可以通过字典查看其类别 47 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径 48 images += glob.glob(os.path.join(self.root,name,'*.jpg')) 49 random.shuffle(images) 50 with open(os.path.join(self.root,filename),'w') as f: 51 writer = csv.writer(f) 52 for img in images: 53 name = img.split(os.sep) 54 label = self.name2label[name[-2]] 55 writer.writerow([img,label]) 56 # 从csv中读取文件 57 images, labels = [], [] 58 with open(os.path.join(self.root,filename),'r') as f: 59 reader = csv.reader(f) 60 for row in reader: 61 img,label = row 62 label = int(label) 63 images.append(img) 64 labels.append(label) 65 assert len(images) == len(labels) # 保证数据长度一致 66 return images,labels 67 68 def __len__(self): 69 return len(self.images) 70 71 def __getitem__(self, idx): 72 # idx是[0-len(self.images] 73 # self.images,self.label 74 # img:pokemeon\mew\0001.jpg(这是一个路径)要转变成img数据 75 # label:是数字 76 img, label = self.images[idx], self.labels[idx] 77 tf = transforms.Compose([ 78 lambda x:Image.open(x).convert('RGB'),# string path -> img data 79 transforms.Resize(int(self.resize*1.25), int(self.resize*1.25)), 80 transforms.Randomrotation(15), # 旋转度数 81 transforms.CenterCrop(self.resize),#中心裁剪,保留resize大小 82 transforms.ToTensor(), 83 transforms.Normalize(mean=[0.485,0.456,0.406], 84 std=[0.229,0.224,0.225]) # 归一化之后,范围为-1~1,之前的图片范围为0~1 85 ]) 86 img = tf(img) # 将path转换成数据 87 label = torch.tensor(label) # 将变量label转换成tensor 88 return img,label 89 90 def denormalize(self,x_hat): 91 mean=[0.485,0.456,0.406] 92 std=[0.229,0.224,0.225] 93 # x:[c,h,w] 94 # x_hat = (x-mean)/std 95 # maen[3]->[3,1,1] 96 mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1) 97 std = torch.tensor(std).unsqueeze(1).unsqueeze(1) 98 x = x_hat * std+mean 99 return x 100 101 def main(): 102 import torchvision 103 vis = visdom.Visdom() 104 """ 105 如果存储比较规范的话,可以使用下面简单的代码加载数据集,文件夹的标签从0开始编码 106 tf = transforms.Compose([ 107 transforms.Resize((64,64)), 108 transforms.ToTensor() 109 ]) 110 db = torchvision.datasets.ImageFolder('./pokemon',transform=tf) 111 loader = DataLoader(db,batch_size=32,shuffle=True) 112 print(db.class_to_idx) #查看类标签 113 114 """ 115 db = Pokemon('./pokemon', 224, 'train') # 根据idx,返回一个 116 x,y = next(iter(db)) 117 print('sample:',x.shape,y.shape) 118 #可视化 119 vis.image(db.denormalize(x),win='sample_x',opts=dict(title = 'sample_x')) 120 # 加载一批 121 loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 ) 122 for x,y in loader: 123 vis.images(db.denormalize(x), nrow=8, win='batch',opts=dict(title='batch')) 124 vis.text(str(y.numpy()),win='label',opts=dict(title='batch-y')) 125 126 127 if __name__ == '__main__': 128 main()
小结:
在加载自定义数据集时,一般步骤
1.定义一个类继承Dataset
2.在类中读取数据集(图片的路径),重写len函数,和getitem函数
在len函数中返回数据集的长度
在getitem函数中,处理一张图片,单个图片路径转换成图片数据(包括transform转换),返回该图片数据和标签
3,将处理好的数据集(均为张量)放入DataLoader中,进行分批
loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )
4.训练时通过enumerate遍历每个batchsize