zoukankan      html  css  js  c++  java
  • CrowdHuman数据集标注格式转换为YOLOv3可以使用的COCO格式

    需要了解CrowdHuman的数据标注格式odgt,YOLOv3需要的COCO格式(不需要使用json文件,只需要图片位置信息和标注信息)

    YOLOv3 github地址:https://github.com/eriklindernoren/PyTorch-YOLOv3

    保存每一张图片的位置信息

     1 import os
     2 import json
     3 
     4 
     5 def load_file(fpath):  # fpath是具体的文件 ,作用:#str to list
     6     assert os.path.exists(fpath)  # assert() raise-if-not
     7     with open(fpath, 'r') as fid:
     8         lines = fid.readlines()
     9     records = [json.loads(line.strip('
    ')) for line in lines]  # str to list
    10     return records
    11 
    12 
    13 def img2txt(odgtpath, respath):
    14     records = load_file(odgtpath)  # 提取odgt文件数据
    15     record_list = len(records)  # 获得record的长度,循环遍历所有数据。
    16     print(os.getcwd())
    17     # os.mkdir(os.getcwd() + respath)
    18     with open(respath, 'w') as txt:
    19         for i in range(record_list):
    20             file_name = records[i]['ID'] + '.jpg'
    21             file_name = str("/datasets/crowdhuman/images/val/Image/" + file_name)
    22             txt.write(file_name + '
    ')
    23 
    24 
    25 if __name__ == '__main__':
    26     odgtpath = "/datasets/crowdhuman/annotation_val.odgt"
    27     respath = "/datasets/crowdhuman/val_name.txt"
    28     img2txt(odgtpath, respath)

    保存每一张图片标注信息中的全身坐标fbox

     1 import time
     2 import img2txt
     3 from PIL import Image
     4 
     5 
     6 def tonormlabel(odgtpath, storepath):
     7     records = img2txt.load_file(odgtpath)
     8     record_list = len(records)
     9     print(record_list)
    10     categories = {}
    11     # txt = open(respath, 'w')
    12     for i in range(record_list):
    13         txt_name = storepath + records[i]['ID'] + '.txt'
    14         file_name = records[i]['ID'] + '.jpg'
    15         #print(i)
    16         im = Image.open("/datasets/crowdhuman/images/train_all/Image/" + file_name)
    17         height = im.size[1]
    18         width = im.size[0]
    19         file = open(txt_name, 'w')
    20         gt_box = records[i]['gtboxes']
    21         gt_box_len = len(gt_box)  # 每一个字典gtboxes里,也有好几个记录,分别提取记录。
    22         for j in range(gt_box_len):
    23             category = gt_box[j]['tag']
    24             if category not in categories:  # 该类型不在categories,就添加上去
    25                 new_id = len(categories) + 1  # ID递增
    26                 categories[category] = new_id
    27             category_id = categories[category]  # 重新获取它的类别ID
    28             fbox = gt_box[j]['fbox']  # 获得全身框
    29             norm_x = fbox[0] / width
    30             norm_y = fbox[1] / height
    31             norm_w = fbox[2] / width
    32             norm_h = fbox[3] / height
    33             '''
    34             norm_x = 0 if norm_x <= 0 else norm_x
    35             norm_x = 1 if norm_x >= 1 else norm_x
    36             norm_y = 0 if norm_y <= 0 else norm_y
    37             norm_y = 1 if norm_y >= 1 else norm_y
    38             norm_w = 0 if norm_w <= 0 else norm_w
    39             norm_w = 1 if norm_w >= 1 else norm_w
    40             norm_h = 0 if norm_h <= 0 else norm_h
    41             norm_h = 1 if norm_h >= 1 else norm_h
    42             '''
    43             blank = ' '
    44             if j == gt_box_len-1:
    45                 file.write(str(category_id - 1) + blank + '{:.6f}'.format(norm_x) + blank + '{:.6f}'.format(norm_y) + blank
    46                            + '{:.6f}'.format(norm_w) + blank + '{:.6f}'.format(norm_h))
    47             else:
    48                 file.write(str(category_id - 1) + blank + '{:.6f}'.format(norm_x) + blank + '{:.6f}'.format(norm_y) + blank
    49                            + '{:.6f}'.format(norm_w) + blank + '{:.6f}'.format(norm_h) + '
    ')
    50 
    51 
    52 if __name__ == '__main__':
    53     odgtpath = "/datasets/crowdhuman/annotation_train.odgt"  
    54     storepath = "/datasets/crowdhuman/labels/train_all/Image/"
    55     print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))  # 格式化输出时间
    56     start = time.time()
    57     tonormlabel(odgtpath, storepath)
    58     end = time.time()
    59     print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    60     print('已完成转换,共耗时{:.5f}s'.format(end - start))
  • 相关阅读:
    201671010146 2017-2 <表格监督>
    201671010146 2017-2 《Java学期末有感》
    201671010146 2017-2 《Java线程》
    201671010146 2017-2 《第十六周学习Java有感》
    201671010146 2017―2 《第16周学习java有感》
    201671010146 2017―2 《第十五周学习java有感》
    201671010146 2017-2 《java第十一章学习感悟》
    201671010146 2017-2 《第十章学习感悟》
    201671010146 2017―2 《第11周学习java有感》
    201671010146 2017-2 《java第八章学习感悟》
  • 原文地址:https://www.cnblogs.com/DJames23/p/13395699.html
Copyright © 2011-2022 走看看