  • Converting object-detection .xml annotations to TFRecord format for TensorFlow training

    This script converts object-detection annotation data stored as .xml files into the TFRecord format used for TensorFlow training.
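
    Judging from the fields it reads (size/width, size/height, and per-object name, difficult and bndbox), the script expects annotations in a Pascal VOC-style layout. The sketch below parses a small, made-up annotation with the same ElementTree calls; the sample XML and the helper name read_objects are illustrative assumptions, not part of the original post.

```python
import xml.etree.ElementTree as ET

# Hypothetical annotation written in the VOC-style layout the script reads.
SAMPLE_XML = """
<annotation>
  <size><width>200</width><height>100</height><depth>3</depth></size>
  <object>
    <name>cat</name>
    <difficult>0</difficult>
    <bndbox><xmin>50</xmin><ymin>20</ymin><xmax>150</xmax><ymax>60</ymax></bndbox>
  </object>
  <object>
    <name>dog</name>
    <difficult>1</difficult>
    <bndbox><xmin>10</xmin><ymin>10</ymin><xmax>40</xmax><ymax>90</ymax></bndbox>
  </object>
</annotation>
"""


def read_objects(xml_text):
    """Return (width, height, objects); each object is a small dict."""
    root = ET.fromstring(xml_text.strip())
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    objects = []
    for obj in root.iter('object'):
        box = obj.find('bndbox')
        objects.append({
            'name': obj.find('name').text,
            'difficult': int(obj.find('difficult').text),
            # Same coordinate order the script passes on: xmin, xmax, ymin, ymax.
            'box': tuple(float(box.find(tag).text)
                         for tag in ('xmin', 'xmax', 'ymin', 'ymax')),
        })
    return width, height, objects


width, height, objects = read_objects(SAMPLE_XML)
print(width, height)
for entry in objects:
    print(entry)
```

    In the script itself, objects whose difficult flag is 1 (the "dog" entry here) or whose name is not in classes are skipped before any box is written.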

    import xml.etree.ElementTree as ET
    import numpy as np
    import os
    import tensorflow as tf
    from PIL import Image
    
    classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    
    
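    def convert(size, box):
        # size is (image_width, image_height) in pixels.
        # box is (xmin, xmax, ymin, ymax) in pixels.
        # Returns [x_center, y_center, width, height], each divided by the image size.
        dw = 1./size[0]
        dh = 1./size[1]
        x = (box[0] + box[1])/2.0
        y = (box[2] + box[3])/2.0
        w = box[1] - box[0]
        h = box[3] - box[2]
        x = x*dw
        w = w*dw
        y = y*dh
        h = h*dh
        return [x, y, w, h]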
    
    
    def convert_annotation(image_id):
        # ET.parse accepts a path directly, so no file handle is left open.
        tree = ET.parse('F:/xml/%s.xml' % (image_id))
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        bboxes = []
        for i, obj in enumerate(root.iter('object')):
            if i > 29:
                break
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # Box order is (xmin, xmax, ymin, ymax), which is what convert() expects.
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b) + [cls_id]
            bboxes.extend(bb)
        # Zero-pad to a fixed length of 30 boxes x 5 values (x, y, w, h, class id).
        if len(bboxes) < 30*5:
            bboxes = bboxes + [0, 0, 0, 0, 0]*(30 - len(bboxes)//5)
    
        return np.array(bboxes, dtype=np.float32).flatten().tolist()
    
    def convert_img(image_id):
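        image = Image.open('F:/snow leopard/test_im/%s.jpg' % (image_id))
        # Force 3 channels so every record has the same 416x416x3 layout.
        resized_image = image.convert('RGB').resize((416, 416), Image.BICUBIC)
        # Scale pixels to [0, 1] and store them as raw float32 bytes.
        image_data = np.array(resized_image, dtype='float32')/255
        img_raw = image_data.tobytes()
        return img_raw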
    
    filename = 'test.tfrecords'
    # tf.python_io is the TensorFlow 1.x location of the writer;
    # TensorFlow 2.x exposes it as tf.io.TFRecordWriter.
    writer = tf.python_io.TFRecordWriter(filename)
    # image_ids = open('F:/snow leopard/test_im/%s.txt' % (
    #     year, year, image_set)).read().strip().split()
    
    image_ids = os.listdir('F:/snow leopard/test_im/')
    # print(filename)
    for image_id in image_ids:
        # Strip only the final extension so file names containing dots stay intact.
        image_id = os.path.splitext(image_id)[0]
        print(image_id)
    
        xywhc = convert_annotation(image_id)
        img_raw = convert_img(image_id)
    
        example = tf.train.Example(features=tf.train.Features(feature={
            'xywhc':
                    tf.train.Feature(float_list=tf.train.FloatList(value=xywhc)),
            'img':
                    tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
        writer.write(example.SerializeToString())
    writer.close()
    

      

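    Two ways to list the images in a folder with Python:

    import os
    imagelist = os.listdir('./images/')      # names of every file in the images folder
    import glob
    imagelist = sorted(glob.glob('./images/frame_*.png'))      # paths of the images that match a pattern; preferable to the first method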
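
    The difference between the two methods is easier to see on a concrete folder. The sketch below builds a temporary directory with made-up file names (all of them assumptions for the demonstration) and compares what os.listdir and glob.glob return.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as folder:
    # Hypothetical file names, created only for this demonstration.
    for name in ('frame_2.png', 'frame_1.png', 'notes.txt'):
        open(os.path.join(folder, name), 'w').close()

    # os.listdir returns the bare name of every entry, in arbitrary order.
    all_names = sorted(os.listdir(folder))

    # glob.glob returns paths (folder included) that match the pattern.
    frame_paths = sorted(glob.glob(os.path.join(folder, 'frame_*.png')))
    frame_names = [os.path.basename(path) for path in frame_paths]

    # os.path.splitext removes only the final extension.
    stems = [os.path.splitext(name)[0] for name in frame_names]

print(all_names)    # ['frame_1.png', 'frame_2.png', 'notes.txt']
print(frame_names)  # ['frame_1.png', 'frame_2.png']
print(stems)        # ['frame_1', 'frame_2']
```

    Keep in mind that sorted() orders names as strings, so frame_10.png would come before frame_2.png unless the frame numbers are zero-padded.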


    References:

    https://blog.csdn.net/CV_YOU/article/details/80778392

    https://github.com/raytroop/YOLOv3_tf

  • Original post: https://www.cnblogs.com/Allen-rg/p/10245729.html