zoukankan      html  css  js  c++  java
  • 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练。

    import xml.etree.ElementTree as ET
    import numpy as np
    import os
    import tensorflow as tf
    from PIL import Image
    
    classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    
    
    def convert(size, box):
        dw = 1./size[0]
        dh = 1./size[1]
        x = (box[0] + box[1])/2.0
        y = (box[2] + box[3])/2.0
        w = box[1] - box[0]
        h = box[3] - box[2]
        x = x*dw
        w = w*dw
        y = y*dh
        h = h*dh
        return [x, y, w, h]
    
    
    def convert_annotation(image_id):
        in_file = open('F:/xml/%s.xml'%(image_id))
    
        tree = ET.parse(in_file)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        bboxes = []
        for i, obj in enumerate(root.iter('object')):
            if i > 29:
                break
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b) + [cls_id]
            bboxes.extend(bb)
        if len(bboxes) < 30*5:
            bboxes = bboxes + [0, 0, 0, 0, 0]*(30-int(len(bboxes)/5))
    
        return np.array(bboxes, dtype=np.float32).flatten().tolist()
    
    def convert_img(image_id):
        image = Image.open('F:/snow leopard/test_im/%s.jpg' % (image_id))
        resized_image = image.resize((416, 416), Image.BICUBIC)
        image_data = np.array(resized_image, dtype='float32')/255
        img_raw = image_data.tobytes()
        return img_raw
    
    filename = os.path.join('test'+'.tfrecords')
    writer = tf.python_io.TFRecordWriter(filename)
    # image_ids = open('F:/snow leopard/test_im/%s.txt' % (
    #     year, year, image_set)).read().strip().split()
    
    image_ids = os.listdir('F:/snow leopard/test_im/')
    # print(filename)
    for image_id in image_ids:
        print (image_id)
        image_id = image_id.split('.')[0]
        print (image_id)
    
        xywhc = convert_annotation(image_id)
        img_raw = convert_img(image_id)
    
        example = tf.train.Example(features=tf.train.Features(feature={
            'xywhc':
                    tf.train.Feature(float_list=tf.train.FloatList(value=xywhc)),
            'img':
                    tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
        writer.write(example.SerializeToString())
    writer.close()
    

      

    Python读取文件夹下图片的两种方法:

    import os
    imagelist = os.listdir('./images/')      #读取images文件夹下所有文件的名字
    import glob
    imagelist= sorted(glob.glob('./images/' + 'frame_*.png'))      #读取带有相同关键字的图片名字,比上一中方法好


    参考:

    https://blog.csdn.net/CV_YOU/article/details/80778392

    https://github.com/raytroop/YOLOv3_tf

  • 相关阅读:
    信息系统项目管理师2009年上午试题分析与解答
    信息系统项目管理师2005年上半年试题
    信息系统项目管理师2008年下半年试题
    信息系统项目管理师历年上午试题答案及试题和大纲
    信息系统项目管理师2008年上半年试题
    信息系统项目管理师2005年下半年试题
    信息系统项目管理师2006年下半年试题
    一个经典的问题(构造函数调用+抽象类+间接继承抽象类)
    重载构造函数+复用构造函数+原始构造与This引用的区别(一步步案例分析)
    GetType()与Typeof()的区别 举了2个案例
  • 原文地址:https://www.cnblogs.com/Allen-rg/p/10245729.html
Copyright © 2011-2022 走看看