  • PyTorch Custom Datasets

    Prepare the data

    Prepare the COCO128 dataset, which consists of the first 128 images of COCO train2017. Its directory layout, organized as YOLOv5 expects:

    $ tree ~/datasets/coco128 -L 2
    /home/john/datasets/coco128
    ├── images
    │   └── train2017
    │       ├── ...
    │       └── 000000000650.jpg
    ├── labels
    │   └── train2017
    │       ├── ...
    │       └── 000000000650.txt
    ├── LICENSE
    └── README.txt
    

    For details, see Train Custom Data.
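    Each label file holds one object per line in the form `class x_center y_center width height`, with the four coordinates normalized to [0, 1] by the image width and height (the YOLOv5 label format). Below is a minimal vectorized sketch that turns such text into pixel-space [x1, y1, x2, y2] boxes with NumPy. It assumes well-formed 5-column lines, and `parse_yolo_labels` is a hypothetical helper name:

```python
import numpy as np


def parse_yolo_labels(text: str, width: int, height: int):
    """Hypothetical helper: parse YOLOv5 label text into (classes, boxes).

    Each non-empty line is "class cx cy w h" with cx, cy, w, h normalized
    to [0, 1]. Returns int64 class ids of shape (N,) and float32 pixel-space
    boxes of shape (N, 4) in [x1, y1, x2, y2] order.
    """
    rows = [line.split() for line in text.strip().splitlines() if line.strip()]
    # reshape keeps a (0, 5) shape when the file is empty
    labels = np.array(rows, dtype=np.float32).reshape(-1, 5)
    cx, cy, w, h = labels[:, 1], labels[:, 2], labels[:, 3], labels[:, 4]
    # center/size (normalized) -> corners (pixels)
    boxes = np.stack([
        (cx - w / 2) * width,
        (cy - h / 2) * height,
        (cx + w / 2) * width,
        (cy + h / 2) * height,
    ], axis=1)
    return labels[:, 0].astype(np.int64), boxes.astype(np.float32)
```

    For example, `parse_yolo_labels("45 0.5 0.5 0.2 0.4", 640, 480)` yields class 45 with a box of roughly [256, 144, 384, 336].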

    Define the Dataset

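    torch.utils.data.Dataset is the abstract class that represents a dataset. To define a custom dataset, subclass Dataset and override the following methods:

    • __len__: so that len(dataset) returns the size of the dataset.
    • __getitem__: so that dataset[i] returns the i-th sample.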
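    A toy sketch of these two methods, assuming only PyTorch. `SquaresDataset` is a hypothetical class for illustration, where item i is the pair (i, i²):

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy map-style dataset: item i is (i, i * i) as tensors."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        # called by len(dataset)
        return self.n

    def __getitem__(self, i: int):
        # called by dataset[i]; raising IndexError past the end also lets a
        # plain `for item in dataset` loop stop
        if not 0 <= i < self.n:
            raise IndexError(i)
        x = torch.tensor(float(i))
        return x, x * x
```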

    For details, see:

    An example custom implementation of a YOLOv5-format dataset:

    from pathlib import Path
    from typing import Any, Callable, Optional, Tuple, Union
    
    import numpy as np
    import torch
    import torchvision
    from PIL import Image
    
    
    class YOLOv5(torchvision.datasets.vision.VisionDataset):
    
      def __init__(
        self,
        root: Union[str, Path],
        name: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
      ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        images_dir = Path(root) / 'images' / name
        labels_dir = Path(root) / 'labels' / name
        # order follows the filesystem; wrap in sorted() for a deterministic order
        self.images = list(images_dir.iterdir())
        # each label file shares its image's stem; None if the image has no label
        self.labels = []
        for image in self.images:
          label = labels_dir / f'{image.stem}.txt'
          self.labels.append(label if label.exists() else None)
    
      def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        img = Image.open(self.images[idx]).convert('RGB')
    
        label_file = self.labels[idx]
        if label_file is not None:  # found
          with open(label_file, 'r') as f:
            labels = [x.split() for x in f.read().strip().splitlines()]
          # one row per object: class x_center y_center width height (normalized)
          labels = np.array(labels, dtype=np.float32).reshape(-1, 5)
        else:  # missing
          labels = np.zeros((0, 5), dtype=np.float32)
    
        # convert normalized center/size to pixel [x1, y1, x2, y2]
        boxes = []
        classes = []
        for label in labels:
          x, y, w, h = label[1:]
          boxes.append([
            (x - w/2) * img.width,
            (y - h/2) * img.height,
            (x + w/2) * img.width,
            (y + h/2) * img.height])
          classes.append(label[0])
    
        target = {}
        # reshape keeps shape (0, 4) for images without objects
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        target["labels"] = torch.as_tensor(classes, dtype=torch.int64)
    
        if self.transforms is not None:
          img, target = self.transforms(img, target)
    
        return img, target
    
      def __len__(self) -> int:
        return len(self.images)
    

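    The implementation above subclasses torchvision's VisionDataset (itself a subclass of Dataset). Its __getitem__ returns:

    • image: a PIL Image of height H and width W
    • target: dict with the following fields:
      • boxes (FloatTensor[N, 4]): ground-truth boxes [x1, y1, x2, y2], with x in [0, W] and y in [0, H]
      • labels (Int64Tensor[N]): the class label of each box above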
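    Because __getitem__ hands img and target to self.transforms together, a joint transform can keep the boxes consistent with geometric changes to the image, which an image-only transform such as ToTensor cannot do. Below is a minimal sketch of a random horizontal flip, assuming the PIL image and target dict described above. `RandomHorizontalFlipWithBoxes` is a hypothetical name, not a torchvision class:

```python
import random
from typing import Any, Dict, Tuple

import torch
from PIL import Image, ImageOps


class RandomHorizontalFlipWithBoxes:
    """Hypothetical joint transform: mirror a PIL image left-right with
    probability p and mirror its [x1, y1, x2, y2] boxes to match."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, img: Image.Image,
                 target: Dict[str, Any]) -> Tuple[Image.Image, Dict[str, Any]]:
        if random.random() < self.p:
            img = ImageOps.mirror(img)
            boxes = target["boxes"].clone()
            # x1' = W - x2 and x2' = W - x1; y coordinates are unchanged
            boxes[:, [0, 2]] = img.width - target["boxes"][:, [2, 0]]
            # build a new dict so the caller's target is not mutated
            target = {**target, "boxes": boxes}
        return img, target
```

    It would be passed as `transforms=RandomHorizontalFlipWithBoxes(0.5)`. VisionDataset rejects `transforms` combined with `transform` or `target_transform`, so a tensor conversion would then also have to happen inside a joint transform.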

    Read the Dataset

    dataset = YOLOv5(Path.home() / 'datasets/coco128', 'train2017')
    print(f'dataset: {len(dataset)}')
    print(f'dataset[0]: {dataset[0]}')
    

    Output:

    dataset: 128
    dataset[0]: (<PIL.Image.Image image mode=RGB size=640x480 at 0x7F6F9464ADF0>, {'boxes': tensor([[249.7296, 200.5402, 460.5399, 249.1901],
            [448.1702, 363.7198, 471.1501, 406.2300],
            ...
            [  0.0000, 188.8901, 172.6400, 280.9003]]), 'labels': tensor([44, 51, 51, 51, 51, 44, 44, 44, 44, 44, 45, 45, 45, 45, 45, 45, 45, 45,
            45, 50, 50, 50, 51, 51, 60, 42, 44, 45, 45, 45, 50, 51, 51, 51, 51, 51,
            51, 44, 50, 50, 50, 45])})
    

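    To preview a sample, the boxes can be drawn onto the image with Pillow's ImageDraw. Below is a minimal sketch that assumes the [x1, y1, x2, y2] pixel boxes returned above. `draw_boxes` is a hypothetical helper, and the color and line width are arbitrary choices:

```python
from PIL import Image, ImageDraw


def draw_boxes(img, boxes, labels=None, color=(255, 0, 0), width=2):
    """Hypothetical helper: return a copy of a PIL image with
    [x1, y1, x2, y2] boxes (and optional class ids) drawn on it."""
    out = img.convert("RGB")  # convert() returns a new image, so img is untouched
    draw = ImageDraw.Draw(out)
    for i, box in enumerate(boxes):
        # float() accepts Python numbers as well as 0-d tensors
        x1, y1, x2, y2 = (round(float(v)) for v in box)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
        if labels is not None:
            draw.text((x1 + 2, y1 + 2), str(int(labels[i])), fill=color)
    return out
```

    With the dataset above, something like `img, target = dataset[0]` followed by `draw_boxes(img, target["boxes"], target["labels"]).save("preview.jpg")` would write an annotated copy.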

    Use a DataLoader

    Training consumes data in batches, which a DataLoader provides:

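    from torch.utils.data import DataLoader
    
    dataset = YOLOv5(Path.home() / 'datasets/coco128', 'train2017',
      transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()  # PIL image -> FloatTensor[3, H, W] in [0, 1]
      ]))
    
    # keep variable-size images and targets as tuples instead of stacking them
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True,
                            collate_fn=lambda batch: tuple(zip(*batch)))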
    
    for batch_i, (images, targets) in enumerate(dataloader):
      print(f'batch {batch_i}, images {len(images)}, targets {len(targets)}')
      print(f'  images[0]: shape={images[0].shape}')
      print(f'  targets[0]: {targets[0]}')
    

    Output:

    batch 0, images 64, targets 64
      images[0]: shape=torch.Size([3, 480, 640])
      targets[0]: {'boxes': tensor([[249.7296, 200.5402, 460.5399, 249.1901],
            [448.1702, 363.7198, 471.1501, 406.2300],
            ...
            [  0.0000, 188.8901, 172.6400, 280.9003]]), 'labels': tensor([44, 51, 51, 51, 51, 44, 44, 44, 44, 44, 45, 45, 45, 45, 45, 45, 45, 45,
            45, 50, 50, 50, 51, 51, 60, 42, 44, 45, 45, 45, 50, 51, 51, 51, 51, 51,
            51, 44, 50, 50, 50, 45])}
    batch 1, images 64, targets 64
      images[0]: shape=torch.Size([3, 248, 640])
      targets[0]: {'boxes': tensor([[337.9299, 167.8500, 378.6999, 191.3100],
            [383.5398, 148.4501, 452.6598, 191.4701],
            [467.9299, 149.9001, 540.8099, 193.2401],
            [196.3898, 142.7200, 271.6896, 190.0999],
            [134.3901, 154.5799, 193.9299, 189.1699],
            [ 89.5299, 162.1901, 124.3798, 188.3301],
            [  1.6701, 154.9299,  56.8400, 188.3700]]), 'labels': tensor([20, 20, 20, 20, 20, 20, 20])}
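    The collate_fn above is needed because the default collate function stacks samples into one tensor. That fails when images differ in size (3x480x640 and 3x248x640 in the output above) and when targets hold different numbers of boxes. `tuple(zip(*batch))` instead transposes a list of (image, target) pairs into a tuple of images and a tuple of targets. A lambda cannot be pickled, which matters once `num_workers > 0` starts workers with the spawn method (the default on Windows and macOS). Below is a sketch of a module-level equivalent. `detection_collate` is a hypothetical name:

```python
from typing import Any, Sequence, Tuple


def detection_collate(
    batch: Sequence[Tuple[Any, Any]],
) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    """Hypothetical collate_fn: turn [(img0, tgt0), (img1, tgt1), ...] into
    ((img0, img1, ...), (tgt0, tgt1, ...)) without stacking anything.

    Defined at module level (not as a lambda) so that it can be pickled
    when DataLoader workers are started with the spawn method.
    """
    images, targets = zip(*batch)
    return images, targets
```

    It would replace the lambda, for example `DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=detection_collate, num_workers=2)`.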
    

    Source code

    References

    APIs:

    GoCoding: sharing experience from personal practice. Follow the WeChat official account!

  • Original article: https://www.cnblogs.com/gocodinginmyway/p/14439879.html