zoukankan      html  css  js  c++  java
  • COCO数据集转mask

    书接上文,先马克一下,之后再改

    # -*- coding: utf-8 -*-
    """
    Created on Wed Jul  1 14:45:07 2020
    
    @author: mhshao
    """
    from pycocotools.coco import COCO
    import os
    import shutil
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    import cv2
    from PIL import Image, ImageDraw
    import skimage.io as io
    import json
    import numpy as np
    '''
    路径参数
    '''
    #原coco数据集的路径
    dataDir= 'newdata/'
    #用于保存新生成的mask数据的路径
    savepath = "newdata/"
    
    '''
    数据集参数
    '''
    #coco有80类,这里写要进行二值化的类的名字
    #其他没写的会被当做背景变成黑色
    #如我只需要car、bus、truck这三类数据
    classes_names = ['car','bus','truck']  
    #要处理的数据集,比如val2017、train2017等
    #不建议多个数据集在一个list中
    #一次提取一个数据集安全点_(:3」∠❀)_
    datasets_list = ['val2017']
    
    #生成保存路径,函数抄的(›´ω`‹ )
    #if the dir is not exists,make it,else delete it
    def mkr(path):
        if os.path.exists(path):
            shutil.rmtree(path)
            os.mkdir(path)
        else:
            os.mkdir(path)
            
    #生成mask图
    def mask_generator(coco,width,height,anns_list):
        mask_pic = np.zeros((height, width))
        #生成mask
        for single in anns_list:
                mask_single = coco.annToMask(single)
                mask_pic += mask_single
        #转化为255
        for row in range(height):
                for col in range(width):
                    if (mask_pic[row][col] > 0):
                        mask_pic[row][col] = 255
        mask_pic = mask_pic.astype(int)
        '''
        #转为三通道
        imgs = np.zeros(shape=(height, width, 3), dtype=np.float32)
        imgs[:, :, 0] = mask_pic[:, :]
        imgs[:, :, 1] = mask_pic[:, :]
        imgs[:, :, 2] = mask_pic[:, :]
        imgs = imgs.astype(int)
        '''    
        return mask_pic
        
    #处理json数据并保存二值mask
    def get_mask_data(annFile,mask_to_save):
        #获取COCO_json的数据
        coco = COCO(annFile)
        #拿到所有需要的图片数据的id
        classes_ids = coco.getCatIds(catNms = classes_names)
        #取所有类别的并集的所有图片id
        #如果想要交集,不需要循环,直接把所有类别作为参数输入,即可得到所有类别都包含的图片
        imgIds_list = []
        for idx in classes_ids:
            imgidx = coco.getImgIds(catIds=idx)
            imgIds_list += imgidx
        #去除重复的图片
        imgIds_list = list(set(imgIds_list))
        
        #一次性获取所有图像的信息
        image_info_list = coco.loadImgs(imgIds_list)
        
        #对每张图片生成一个mask
        for imageinfo in image_info_list:
            #获取对应类别的分割信息
            annIds = coco.getAnnIds(imgIds = imageinfo['id'], catIds = classes_ids, iscrowd=None)
            anns_list = coco.loadAnns(annIds)
            #生成二值mask图
            mask_image = mask_generator(coco,imageinfo['width'],imageinfo['height'],anns_list)
            #保存图片
            file_name = mask_to_save + '/' +imageinfo['file_name'][:-4]+'.png'
            plt.imsave(file_name , mask_imageif __name__ == '__main__':
        #按单个数据集进行处理
        for dataset in datasets_list:
            #用来保存最后生成的mask图像目录
            mask_to_save = savepath + 'masks/' + dataset
            mkr(savepath + 'masks/')
            #生成路径
            mkr(mask_to_save)
    
            #获取要处理的json文件路径
            #我这里用了之前自己生成的部分类别json
            #具体方法见我前一篇博客
            annFile='{}/annotations/instances_{}_sub.json'.format(dataDir,dataset)
            #处理数据
            get_mask_data(annFile,mask_to_save)
            print('Got all the masks of {} from {} ٩( ๑╹ ꇴ ╹)۶'.format(classes_names,dataset))

    000000001532.png

     

    000000097924.png

    000000121242.png

  • 相关阅读:
    LeetCode 172:阶乘后的零
    Ubuntu12.04更新出现 The system is running in low-graphics mode解决方法
    不加参数的存储过程
    PCC-S-02201, Encountered the symbol "DB_USER_OPER_COUNT"
    该思考
    关于export环境变量生存期
    会话临时表 ORA-14452
    如何创建守护进程--及相关概念
    2014年10月末
    6个月
  • 原文地址:https://www.cnblogs.com/lhdb/p/13221302.html
Copyright © 2011-2022 走看看