zoukankan      html  css  js  c++  java
  • Pytorch_COCO数据集_dataset

    COCO 数据集

    本文主要内容来源于《pytorch加载自己的coco数据集》，在其基础上进行学习和理解，以加深对 COCO 数据集格式的认识，并梳理将自己的数据封装为 Dataset 的步骤。仅供学习使用。
     重点：了解 Dataset 的输入和输出。
    

    代码示例

    #!/usr/bin/env python3
    # -*- coding: UTF-8 -*-
    
    import os
    import os.path
    import json
    import cv2
    import numpy as np
    import torch
    from torch.utils.data import Dataset
    from torch.utils.data import TensorDataset
    from torchvision.transforms import functional as F
    
    
    # Step 1: define CoCo_DataSet as a torch Dataset subclass overriding
    # __init__, __len__ and __getitem__.
    class CoCo_DataSet(Dataset):
        """Detection dataset backed by a COCO-style instances JSON file.

        Expected layout under ``coco_root_dir``::

            annotations/coco_instance_train.json
            annotations/coco_instance_val.json
            images/train2021/<file_name>
            images/val2021/<file_name>

        Each item is ``(image, target)`` with
        ``target = {"bboxes": float32 tensor (N, 4), "labels": int64 tensor (N,)}``.
        Boxes are copied verbatim from each annotation's ``bbox`` field; no
        coordinate-format conversion is performed.
        """

        def __init__(self, coco_root_dir, transforms, train_set=True):
            """Load the annotation JSON and group annotations per image.

            Args:
                coco_root_dir: dataset root directory (layout described above).
                transforms: callable ``(image, target) -> (image, target)``
                    applied in ``__getitem__``, or ``None`` to skip.
                train_set: use the train split if True, otherwise the val split.

            Raises:
                FileNotFoundError: the annotation JSON is missing or is not a
                    regular file.
            """
            self.transforms = transforms
            self.annotations_root = os.path.join(coco_root_dir, "annotations")
            if train_set:
                json_name, image_dir = "coco_instance_train.json", "train2021"
            else:
                json_name, image_dir = "coco_instance_val.json", "val2021"
            self.annotations_json = os.path.join(self.annotations_root, json_name)
            self.image_root = os.path.join(coco_root_dir, "images", image_dir)

            # Explicit raise instead of assert (asserts vanish under python -O);
            # isfile also rejects a directory that happens to have this name.
            if not os.path.isfile(self.annotations_json):
                raise FileNotFoundError(
                    "{} file not exist ".format(self.annotations_json))

            # Context manager guarantees the file handle is closed.
            with open(self.annotations_json, mode="r", encoding="utf8") as json_file:
                self.coco_dict = json.load(json_file)

            # Maps (image_id - 1) -> [[category_id, bbox], ...]. The "- 1" key
            # offset is kept so existing readers of ``bbox_image`` still work;
            # __getitem__ translates through each image's real "id", so ids do
            # not need to be 1-based, contiguous, or in "images" order.
            self.bbox_image = {}
            # Test-split COCO files may omit "annotations" entirely.
            for ann in self.coco_dict.get("annotations", []):
                key = ann["image_id"] - 1
                self.bbox_image.setdefault(key, []).append(
                    [ann["category_id"], ann["bbox"]])

        def __len__(self):
            """Return the number of entries in the JSON "images" list."""
            return len(self.coco_dict["images"])

        def __getitem__(self, idx):
            """Return ``(image, target)`` for the idx-th entry of "images".

            Best-effort: prints a message and returns ``None`` when the image
            file is missing or cannot be decoded. Note that ``collate_fn``
            cannot handle a ``None`` sample.
            """
            image_info = self.coco_dict["images"][idx]
            pict_path = os.path.join(self.image_root, image_info["file_name"])
            if not os.path.isfile(pict_path):
                print(pict_path + '@does not exist!')
                return None
            # cv2.imread returns a BGR array, or None if the file can't be decoded.
            image = cv2.imread(pict_path)
            if image is None:
                print(pict_path + '@could not be read!')
                return None
            target = self._build_target(idx)
            if self.transforms is not None:
                image, target = self.transforms(image, target)
            return image, target

        def _build_target(self, idx):
            """Build the ``{"bboxes", "labels"}`` dict for the idx-th image."""
            image_info = self.coco_dict["images"][idx]
            # Keys in bbox_image are image_id - 1 (see __init__); fall back to
            # the positional index if an image entry carries no "id".
            key = image_info["id"] - 1 if "id" in image_info else idx
            annotations = self.bbox_image.get(key, [])
            labels = [class_id for class_id, _ in annotations]
            bboxes = [bbox for _, bbox in annotations]
            return {
                # reshape keeps a (0, 4) shape for images without any boxes.
                "bboxes": torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4),
                "labels": torch.as_tensor(labels, dtype=torch.int64),
            }

        def collate_fn(self, batch):
            """Transpose a list of ``(image, target)`` pairs into ``(images, targets)``.

            Targets hold a different number of boxes per image, so they are
            kept as tuples instead of being stacked into one tensor.
            """
            return tuple(zip(*batch))
    
    
    
    class Compose:
        """Chain several ``(image, target)`` transforms into one callable."""

        def __init__(self, transforms):
            # Ordered sequence of callables, each (image, target) -> (image, target).
            self.transforms = transforms

        def __call__(self, image, target):
            """Feed the pair through every transform in order; return the result."""
            result = image, target
            for step in self.transforms:
                new_image, new_target = step(*result)
                result = new_image, new_target
            return result
    
    class ToTensor:
        """``(image, target)`` transform converting only the image to a tensor."""

        def __call__(self, image, target):
            # Delegates to torchvision.transforms.functional.to_tensor; the
            # target dict is passed through unchanged.
            return F.to_tensor(image), target
    # Resize transform (dict-based sample interface).
    class Resize:
        """Resize ``sample['image']`` with OpenCV, keeping ``sample['label']``."""

        def __init__(self, output_size: tuple):
            # Passed to cv2.resize as ``dsize``, which OpenCV reads as
            # (width, height), not (height, width).
            self.output_size = output_size

        def __call__(self, sample):
            """Return a new ``{'image', 'label'}`` dict with the resized image.

            NOTE(review): this takes a single dict, so it does not fit
            ``Compose`` (which calls ``t(image, target)``); keys other than
            'image' and 'label' are dropped, and boxes are not rescaled.
            """
            resized = cv2.resize(sample['image'], self.output_size)
            return dict(image=resized, label=sample['label'])
    
    # ToTensor transform (dict-based sample interface).
    class MyToTensor:
        """Dict transform: move the last axis of the image first, wrap as tensor.

        No dtype conversion or scaling is applied; the label is passed through.
        """

        def __call__(self, sample):
            # (H, W, C) -> (C, H, W); torch.from_numpy shares the array's memory.
            channels_first = np.transpose(sample['image'], (2, 0, 1))
            return dict(image=torch.from_numpy(channels_first), label=sample['label'])
    
    if __name__ == "__main__":
        # Demo: load the training split and print every batch.
        data_transform = {
            "train": Compose([ToTensor()]),
            "val": Compose([ToTensor()]),
        }
        # Dataset root; adjust to the local machine.
        coco_root_path = r"D:\data\dataset\coco"
        mycocoDataset = CoCo_DataSet(coco_root_path, data_transform["train"])
        # The dataset's collate_fn keeps per-image targets as tuples because
        # images carry different numbers of boxes.
        dataloader = torch.utils.data.DataLoader(
            mycocoDataset,
            batch_size=2,
            shuffle=True,
            collate_fn=mycocoDataset.collate_fn,
        )
        for sample_batch in dataloader:
            # collate_fn yields (images, targets): two tuples with one entry
            # per sample. Indexing sample_batch[0][i] would select images only.
            images_batch, targets_batch = sample_batch
            print(images_batch)
            print(targets_batch)
    

    语法说明

     1. Python 3 中判断字典是否包含某个键 —— 例如 arr_dict 是字典，判断键 "int_key" 是否存在：
        方法一：调用 arr_dict.__contains__("int_key")（可行，但一般不直接调用双下划线方法）
    
        方法二（推荐）：使用 in 运算符
         if "int_key" in arr_dict:
             print("存在")
      2. mycocoDataset.__getitem__(1)（等价于 mycocoDataset[1]）返回的数据格式为：
      (image_tensor, {"bboxes": tensor, "labels": tensor})
    

    参考:

     深度网络学习-PyTorch_自定义Datsset  https://www.cnblogs.com/ytwang/p/15239433.html
     pytorch加载自己的coco数据集 https://blog.csdn.net/yangyangne/article/details/120384069 
     DATASETS & DATALOADERS  https://pytorch.org/tutorials/beginner/basics/data_tutorial.html  
     目标检测系列一:如何制作数据集?  http://www.spytensor.com/index.php/archives/48/
  • 相关阅读:
    PATA 1071 Speech Patterns.
    PATA 1027 Colors In Mars
    PATB 1038. 统计同成绩学生(20)
    1036. 跟奥巴马一起编程(15)
    PATA 1036. Boys vs Girls (25)
    PATA 1006. Sign In and Sign Out (25)
    读取web工程目录之外的图片并显示
    DOS命令
    java连接oracle集群
    servlet
  • 原文地址:https://www.cnblogs.com/ytwang/p/15753180.html
Copyright © 2011-2022 走看看