zoukankan      html  css  js  c++  java
  • 将目标检测的标注数据（.xml）转为 TFRecord 格式，用于 TensorFlow 训练

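    将目标检测的标注数据（.xml）转为 TFRecord 格式，用于 TensorFlow 训练。每条记录包含两个字段：'xywhc'（共 30×5 个浮点数，每个框依次为按图像宽高归一化的中心 x、中心 y、宽、高以及类别编号，不足 30 个框时补 0）和 'img'（缩放到 416×416、像素值除以 255 归一化到 [0, 1] 的 float32 图像原始字节）。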

    import xml.etree.ElementTree as ET
    import numpy as np
    import os
    import tensorflow as tf
    from PIL import Image
    
    # The 20 Pascal VOC object categories; a box's class id is its index here.
    classes = [
        "aeroplane", "bicycle", "bird", "boat", "bottle",
        "bus", "car", "cat", "chair", "cow",
        "diningtable", "dog", "horse", "motorbike", "person",
        "pottedplant", "sheep", "sofa", "train", "tvmonitor",
    ]
    
    
    def convert(size, box):
        """Normalize a corner-format box to ``[center_x, center_y, width, height]``.

        ``size`` is ``(image_width, image_height)`` and ``box`` is
        ``(xmin, xmax, ymin, ymax)`` in pixels. Horizontal values are scaled
        by ``1 / image_width`` and vertical values by ``1 / image_height``.
        """
        inv_width = 1. / size[0]
        inv_height = 1. / size[1]
        xmin, xmax, ymin, ymax = box[0], box[1], box[2], box[3]
        return [
            (xmin + xmax) / 2.0 * inv_width,
            (ymin + ymax) / 2.0 * inv_height,
            (xmax - xmin) * inv_width,
            (ymax - ymin) * inv_height,
        ]
    
    
    def convert_annotation(image_id, xml_dir='F:/xml', max_boxes=30):
        """Read one VOC-style XML annotation as a fixed-length list of box values.

        Parses ``<xml_dir>/<image_id>.xml``. Each ``<object>`` qualifies if its
        ``<name>`` is in ``classes`` and it is not flagged
        ``<difficult>1</difficult>``. Each qualifying object contributes
        ``[center_x, center_y, width, height, class_id]``. The box is normalized
        by the ``<size>`` width/height via ``convert``.

        Args:
            image_id: Annotation file stem (without the ``.xml`` extension).
            xml_dir: Directory containing the XML files; defaults to the
                original hard-coded location.
            max_boxes: Number of 5-value box slots in the result. Kept boxes
                beyond this are dropped; unused slots are zero-filled.

        Returns:
            A flat list of ``max_boxes * 5`` floats, passed through float32.
        """
        # ET.parse opens and closes the file itself (the old bare open() leaked).
        root = ET.parse(os.path.join(xml_dir, '%s.xml' % image_id)).getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        bboxes = []
        for obj in root.iter('object'):
            # Cap on *kept* boxes. Counting every <object> (as before) let
            # skipped difficult/unknown objects use up the available slots.
            if len(bboxes) >= max_boxes * 5:
                break
            cls = obj.find('name').text
            difficult = obj.find('difficult')
            # A missing <difficult> tag is treated as "not difficult".
            if cls not in classes or (difficult is not None and int(difficult.text) == 1):
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # Order expected by convert(): (xmin, xmax, ymin, ymax).
            b = tuple(float(xmlbox.find(tag).text)
                      for tag in ('xmin', 'xmax', 'ymin', 'ymax'))
            bboxes.extend(convert((w, h), b) + [cls_id])
        # Zero-pad so every record has the same length.
        bboxes.extend([0] * (max_boxes * 5 - len(bboxes)))
        return np.array(bboxes, dtype=np.float32).flatten().tolist()
    
    def convert_img(image_id, img_dir='F:/snow leopard/test_im', size=(416, 416)):
        """Load ``<img_dir>/<image_id>.jpg`` and return it as raw float32 bytes.

        Processing steps:

        - Convert the image to 3-channel RGB.
        - Resize it to ``size`` (width, height) with bicubic resampling.
        - Scale pixel values to [0, 1].
        - Serialize the array with ``ndarray.tobytes()``.

        Args:
            image_id: Image file stem (without the ``.jpg`` extension).
            img_dir: Directory containing the images; defaults to the original
                hard-coded location.
            size: Target ``(width, height)`` passed to ``Image.resize``.

        Returns:
            ``bytes`` holding the float32 pixel array.
        """
        path = os.path.join(img_dir, '%s.jpg' % image_id)
        # The context manager releases the file handle Image.open keeps open.
        with Image.open(path) as image:
            # Force RGB so grayscale/CMYK/palette JPEGs serialize to the same
            # byte length as ordinary colour images.
            resized_image = image.convert('RGB').resize(size, Image.BICUBIC)
        image_data = np.array(resized_image, dtype='float32') / 255
        return image_data.tobytes()
    
    # Write one tf.train.Example per .jpg in the image directory to test.tfrecords.
    # Each Example holds 'xywhc' (flat box list from convert_annotation) and
    # 'img' (raw float32 image bytes from convert_img).
    filename = 'test.tfrecords'
    image_dir = 'F:/snow leopard/test_im/'

    # Only .jpg files are converted because convert_img() opens '<id>.jpg'.
    # Any other file (e.g. 'foo.txt' next to 'foo.jpg') would otherwise add a
    # duplicate record. splitext keeps stems that contain dots intact, and
    # sorting makes the record order reproducible.
    image_ids = sorted(
        os.path.splitext(name)[0]
        for name in os.listdir(image_dir)
        if os.path.splitext(name)[1].lower() == '.jpg'
    )

    # tf.io.TFRecordWriter is the current name (tf.python_io is gone in TF 2).
    # The context manager closes the file even if a conversion raises.
    with tf.io.TFRecordWriter(filename) as writer:
        for image_id in image_ids:
            print(image_id)
            xywhc = convert_annotation(image_id)
            img_raw = convert_img(image_id)
            example = tf.train.Example(features=tf.train.Features(feature={
                'xywhc': tf.train.Feature(float_list=tf.train.FloatList(value=xywhc)),
                'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())
    

      

    Python读取文件夹下图片的两种方法:

    import glob
    import os

    # Method 1: names of every entry (files and subfolders) in ./images/.
    # os.listdir returns them in arbitrary order, so sort for a stable order.
    imagelist = sorted(os.listdir('./images/'))
    # Method 2: full paths (e.g. './images/frame_001.png') of files matching a
    # shell-style pattern, sorted lexicographically. Preferable when the folder
    # also contains files you do not want.
    imagelist = sorted(glob.glob('./images/' + 'frame_*.png'))


    参考:

    https://blog.csdn.net/CV_YOU/article/details/80778392

    https://github.com/raytroop/YOLOv3_tf

  • 相关阅读:
    27.TreeMap
    26.HashCode
    25.HashTable
    myeclipse快捷键
    spring 配置
    jdbcType和javaType对应关系
    Ajax表单提交
    ajax
    JQuery及Form插件使用
    jsp标准数据库
  • 原文地址:https://www.cnblogs.com/Allen-rg/p/10245729.html
Copyright © 2011-2022 走看看