zoukankan      html  css  js  c++  java
  • Pytorch自定义数据库

    1)前言

    虽然torchvision.datasets中已经封装了好多通用的数据集,但是我们在使用Pytorch做深度学习任务的时候,会面临着自定义数据库来满足自己的任务需要。如我们要训练一个人脸关键点检测算法,提供的训练数据标注如下形式,存在CSV文件中:

    image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
    0805personali01.jpg,27,83,27,98, ... 84,134
    1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

    在本次教程中,我们需要用到两个额外的包:

    • scikit-image: 用于图片io转换
    • pandas: 用于解析csv文件

    首先学习如何使用pandas库解析csv文件

    import pandas as pd
    landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv') n = 65 img_name = landmarks_frame.iloc[n, 0] landmarks = landmarks_frame.iloc[n, 1:].as_matrix() landmarks = landmarks.astype('float').reshape(-1, 2) print('Image name: {}'.format(img_name)) print('Landmarks shape: {}'.format(landmarks.shape)) print('First 4 Landmarks: {}'.format(landmarks[:4]))

    2)自定义数据库

    torch.utils.data.Dataset是一个表示数据库的抽象类,自定义数据库需要继承这个类,并且重写其以下方法:

    __len__ :返回数据库的大小.
    __getitem__ :支持使用下标的方式 如dataset[i] 来获取第i个样本

    以下创建人脸特征点检测的数据库。我们将在__init__中解析csv文件,而在__getitem__中读取图片。这样可以在需要图片是才加载,内存效率高。此外,我们还可以先将数据集封装成lmdb数据库,读取速度更快。

    import torch.utils.data.Dataset as Dataset
    class FaceLandmarksDataset(Dataset):
        """Face Landmarks dataset."""
    
        def __init__(self, csv_file, root_dir, transform=None):
            """
            Args:
                csv_file (string): 到达标注文件cvs的路径.
                root_dir (string): 所有图片的根目录.
                transform (callable, optional): (可选参数)对每一个样本进行转换.
            """
            self.landmarks_frame = pd.read_csv(csv_file)
            self.root_dir = root_dir
            self.transform = transform
    
        def __len__(self):
            return len(self.landmarks_frame)
    
        def __getitem__(self, idx):
            img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0]) #第idx条数据的第一个字段,即文件名称
            image = io.imread(img_name)                           #读取图像数据
            landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix() #读取第idx条数据的第二个字段及其之后的所有字段,即所有关键点的坐标。然后转成矩阵形式
            landmarks = landmarks.astype('float').reshape(-1, 2)  #将矩阵reshape成n行两列矩阵
            sample = {'image': image, 'landmarks': landmarks}     #封装数据
    
            if self.transform:
                sample = self.transform(sample)                   #数据转换
    
            return sample                                         #返回数据

    注:__getitem__每次只返回一个条数据,至于batch的封装可以在DataLoader中设置batchsize,至于读取速度可以设置num_worker。

  • 相关阅读:
    mysql_fetch_row()获取显示数据
    数组上下移动
    croppie 在Angular8 中使用
    关于 element 的 backToTop
    苹果手机new Date()问题
    js精简代码集合
    vue 中使用高德地图, 地图选点
    代替if else 的表单验证方法!
    记一次webpack打包样式加载问题
    echarts 饼图的指示线(labelline) 问题
  • 原文地址:https://www.cnblogs.com/houjun/p/10405466.html
Copyright © 2011-2022 走看看