zoukankan      html  css  js  c++  java
  • 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测的标注数据(.xml 文件)转换为 tfrecord 格式,用于 TensorFlow 训练。

    import xml.etree.ElementTree as ET
    import numpy as np
    import os
    import tensorflow as tf
    from PIL import Image
    
    classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    
    
    def convert(size, box):
        """Turn a corner-style box into a normalised centre/size box.

        Args:
            size: Indexable ``(width, height)`` of the image in pixels.
            box: Indexable ``(xmin, xmax, ymin, ymax)`` in pixels.

        Returns:
            ``[x, y, w, h]`` where ``x``/``y`` are the box centre and
            ``w``/``h`` its extent, each scaled by the reciprocal of the
            matching image dimension.
        """
        inv_width = 1. / size[0]
        inv_height = 1. / size[1]

        x_min, x_max = box[0], box[1]
        y_min, y_max = box[2], box[3]

        center_x = (x_min + x_max) / 2.0
        center_y = (y_min + y_max) / 2.0
        span_x = x_max - x_min
        span_y = y_max - y_min

        return [
            center_x * inv_width,
            center_y * inv_height,
            span_x * inv_width,
            span_y * inv_height,
        ]
    
    
    def convert_annotation(image_id):
        """Load the XML annotation for ``image_id`` as a fixed-length box list.

        Parses ``F:/xml/<image_id>.xml``, keeps every object whose name is in
        the module-level ``classes`` list and that is not flagged difficult,
        and converts each kept bounding box with ``convert``.

        Args:
            image_id: Annotation file name without the ``.xml`` extension.

        Returns:
            A flat list of 150 floats: up to 30 boxes, each stored as
            ``x, y, w, h, class_index`` (rounded through float32), followed by
            all-zero entries that pad the list to exactly 30 boxes.
        """
        max_boxes = 30
        values_per_box = 5

        # Handing the path to ElementTree lets it open the file in binary mode
        # (so the XML encoding declaration is honoured) and close it again,
        # instead of leaking an open handle per image.
        tree = ET.parse('F:/xml/%s.xml' % (image_id))
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)

        bboxes = []
        kept = 0
        for obj in root.iter('object'):
            # Cap on the number of boxes actually kept, so objects that are
            # filtered out below do not use up one of the 30 slots.
            if kept >= max_boxes:
                break
            difficult = obj.find('difficult')
            cls = obj.find('name').text
            # A missing <difficult> tag is treated as "not difficult".
            is_difficult = difficult is not None and int(difficult.text) == 1
            if cls not in classes or is_difficult:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bboxes.extend(convert((w, h), b) + [cls_id])
            kept += 1

        # Pad with all-zero boxes so every record has the same length.
        bboxes.extend([0] * values_per_box * (max_boxes - kept))

        return np.array(bboxes, dtype=np.float32).flatten().tolist()
    
    def convert_img(image_id):
        """Load a JPEG, resize it to 416x416 and return its raw float32 bytes.

        Args:
            image_id: Image file name without the ``.jpg`` extension, looked
                up in ``F:/snow leopard/test_im/``.

        Returns:
            The bytes of a float32 array built from the resized RGB image,
            with pixel values divided by 255.
        """
        # The context manager closes the underlying file once the pixels have
        # been read, instead of leaving one open handle per image.
        with Image.open('F:/snow leopard/test_im/%s.jpg' % (image_id)) as image:
            # Force three channels: the serialized bytes carry no shape, so a
            # grayscale/RGBA/palette image would otherwise yield a buffer of a
            # different size than the RGB ones.
            resized_image = image.convert('RGB').resize((416, 416), Image.BICUBIC)
        image_data = np.array(resized_image, dtype='float32')/255
        img_raw = image_data.tobytes()
        return img_raw
    
    # Write one tf.train.Example per JPEG in the image folder to test.tfrecords.
    # Each example stores the padded box list under 'xywhc' and the raw float32
    # image bytes under 'img'.
    filename = os.path.join('test'+'.tfrecords')

    image_ids = os.listdir('F:/snow leopard/test_im/')

    # The context manager closes the writer even if one image fails to convert,
    # so records written so far are not left in an unclosed file.
    # NOTE(review): tf.python_io is the TF 1.x name; on TF 2.x this is
    # tf.io.TFRecordWriter - confirm the TensorFlow version in use.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for image_name in image_ids:
            # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), and the
            # extension check skips other directory entries, which
            # convert_img() could not open as "<id>.jpg".
            image_id, extension = os.path.splitext(image_name)
            if extension.lower() != '.jpg':
                continue
            print (image_name)
            print (image_id)

            xywhc = convert_annotation(image_id)
            img_raw = convert_img(image_id)

            example = tf.train.Example(features=tf.train.Features(feature={
                'xywhc':
                        tf.train.Feature(float_list=tf.train.FloatList(value=xywhc)),
                'img':
                        tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                }))
            writer.write(example.SerializeToString())
    

      

    Python读取文件夹下图片的两种方法:

    # Two ways to collect the image files in a folder.
    import os
    imagelist = os.listdir('./images/')      # names of every entry in the images folder
    import glob
    imagelist= sorted(glob.glob('./images/' + 'frame_*.png'))      # sorted paths of the files matching the shared "frame_*.png" pattern; preferable to the previous method


    参考:

    https://blog.csdn.net/CV_YOU/article/details/80778392

    https://github.com/raytroop/YOLOv3_tf

  • 相关阅读:
    柔性数组
    2015阿里秋招当中一个算法题(经典)
    LAMP环境搭建
    JS和JQuery中的事件托付 学习笔记
    #17 Letter Combinations of a Phone Number
    码农生涯杂记_5
    【C++ Primer每日刷】之三 标准库 string 类型
    扎根本地连接未来 千米网的电商“红海”生存术
    poj 3356
    经验之谈—OAuth授权流程图
  • 原文地址:https://www.cnblogs.com/Allen-rg/p/10245729.html
Copyright © 2011-2022 走看看