zoukankan      html  css  js  c++  java
  • 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练。

    import xml.etree.ElementTree as ET
    import numpy as np
    import os
    import tensorflow as tf
    from PIL import Image
    
    classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    
    
    def convert(size, box):
        dw = 1./size[0]
        dh = 1./size[1]
        x = (box[0] + box[1])/2.0
        y = (box[2] + box[3])/2.0
        w = box[1] - box[0]
        h = box[3] - box[2]
        x = x*dw
        w = w*dw
        y = y*dh
        h = h*dh
        return [x, y, w, h]
    
    
    def convert_annotation(image_id):
        in_file = open('F:/xml/%s.xml'%(image_id))
    
        tree = ET.parse(in_file)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        bboxes = []
        for i, obj in enumerate(root.iter('object')):
            if i > 29:
                break
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b) + [cls_id]
            bboxes.extend(bb)
        if len(bboxes) < 30*5:
            bboxes = bboxes + [0, 0, 0, 0, 0]*(30-int(len(bboxes)/5))
    
        return np.array(bboxes, dtype=np.float32).flatten().tolist()
    
    def convert_img(image_id):
        image = Image.open('F:/snow leopard/test_im/%s.jpg' % (image_id))
        resized_image = image.resize((416, 416), Image.BICUBIC)
        image_data = np.array(resized_image, dtype='float32')/255
        img_raw = image_data.tobytes()
        return img_raw
    
    filename = os.path.join('test'+'.tfrecords')
    writer = tf.python_io.TFRecordWriter(filename)
    # image_ids = open('F:/snow leopard/test_im/%s.txt' % (
    #     year, year, image_set)).read().strip().split()
    
    image_ids = os.listdir('F:/snow leopard/test_im/')
    # print(filename)
    for image_id in image_ids:
        print (image_id)
        image_id = image_id.split('.')[0]
        print (image_id)
    
        xywhc = convert_annotation(image_id)
        img_raw = convert_img(image_id)
    
        example = tf.train.Example(features=tf.train.Features(feature={
            'xywhc':
                    tf.train.Feature(float_list=tf.train.FloatList(value=xywhc)),
            'img':
                    tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
        writer.write(example.SerializeToString())
    writer.close()
    

      

    Python读取文件夹下图片的两种方法:

    import os
    imagelist = os.listdir('./images/')      #读取images文件夹下所有文件的名字
    import glob
    imagelist= sorted(glob.glob('./images/' + 'frame_*.png'))      #读取带有相同关键字的图片名字,比上一中方法好


    参考:

    https://blog.csdn.net/CV_YOU/article/details/80778392

    https://github.com/raytroop/YOLOv3_tf

  • 相关阅读:
    Tomcat 7.x 6.x 和 JDK 7 旧版本下载教程
    下载ios系统文件,使用UltraISO刻录系统(windows Server)光盘,安装操作系统
    MSDN, 我告诉你
    669. 修剪二叉搜索树
    99. 恢复二叉搜索树
    windows系统 python安装uwsgi教程
    mysql5.7修改密码命令
    客户端浏览器访问django服务器
    二、kafka的环境部署+命令行
    二十、SpringCloud Alibaba Seata处理分布式事务(一、基础)
  • 原文地址:https://www.cnblogs.com/Allen-rg/p/10245729.html
Copyright © 2011-2022 走看看