  • Converting XML to CSV when building your own dataset for TensorFlow

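    When building my own dataset for TensorFlow, I found that the xml-to-csv scripts online are all copies of the same code, and that code led me straight into a pitfall.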

    If the dataset you are training on has only one class, the xml_to_csv script found online works without problems. The source is:

    # -*- coding: utf-8 -*-
    import os
    import glob
    import pandas as pd
    import xml.etree.ElementTree as ET
     
    def xml_to_csv(path):
        xml_list = []
        # Read every annotation file in the folder
        for xml_file in glob.glob(os.path.join(path, '*.xml')):
            tree = ET.parse(xml_file)
            root = tree.getroot()
            for member in root.findall('object'):
                # Positional indices follow the LabelImg / Pascal VOC layout:
                # member[0] is <name>, member[4] is <bndbox> (xmin, ymin, xmax, ymax).
                # '.jpg' is appended here; drop it if <filename> already has an extension.
                value = (root.find('filename').text + '.jpg',
                         int(root.find('size')[0].text),
                         int(root.find('size')[1].text),
                         member[0].text,
                         int(member[4][0].text),
                         int(member[4][1].text),
                         int(member[4][2].text),
                         int(member[4][3].text)
                         )
                xml_list.append(value)
        column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
     
        # Split all rows into a training set and an evaluation set.
        # 0.67 gives roughly 2:1; use 0.75 for the usual 3:1.
        split = int(len(xml_list) * 0.67)
        train_list = xml_list[:split]
        eval_list = xml_list[split:]  # starts exactly at the cut, so no row is lost
     
        # Save as CSV (the data/ folder is created if it does not exist)
        os.makedirs('data', exist_ok=True)
        train_df = pd.DataFrame(train_list, columns=column_name)
        eval_df = pd.DataFrame(eval_list, columns=column_name)
        train_df.to_csv('data/train.csv', index=None)
        eval_df.to_csv('data/eval.csv', index=None)
     
     
    def main():
        path = './xml'
        xml_to_csv(path)
        print('Successfully converted xml to csv.')
     
    if __name__ == '__main__':
        main()
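    The script reads each field by position (`member[0]`, `member[4][0]`, ...). This assumes the LabelImg / Pascal VOC element order `name, pose, truncated, difficult, bndbox` inside each `<object>`. A minimal sketch that looks the fields up by tag name instead is less fragile if some tool writes the elements in another order. It assumes the same standard tag names, and `SAMPLE_XML` and `parse_objects` are hypothetical names used only for illustration:

```python
import xml.etree.ElementTree as ET

# A hypothetical annotation in the Pascal VOC layout that LabelImg writes.
SAMPLE_XML = """<annotation>
  <filename>img_001.jpg</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>stop</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>220</ymax></bndbox>
  </object>
</annotation>"""


def parse_objects(root):
    """Return one CSV row per <object>, looking fields up by tag name."""
    filename = root.find('filename').text
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    rows = []
    for obj in root.findall('object'):
        box = obj.find('bndbox')
        rows.append((
            filename, width, height,
            obj.find('name').text,
            # int(float(...)) also accepts coordinates written as e.g. "12.0"
            *(int(float(box.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')),
        ))
    return rows


if __name__ == '__main__':
    print(parse_objects(ET.fromstring(SAMPLE_XML)))
    # → [('img_001.jpg', 640, 480, 'stop', 10, 20, 110, 220)]
```

    The rows come out in the same column order as `column_name`, so they can be fed to `pd.DataFrame` unchanged.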
    

      

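    If your dataset has two or more classes, however, the script above causes trouble. It cuts the combined list of all annotations once, instead of splitting each class by the same ratio. Because the rows come out in file order, some classes can end up mostly or entirely in one of the two sets.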
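    A toy sketch shows the difference. The labels are made up, not data from the article. With 8 `stop` rows followed by 4 `bicycles` rows, a single 0.67 cut over the combined list puts every `bicycles` row in the evaluation set. Cutting each class separately keeps both classes in both sets. `global_split` and `per_class_split` are hypothetical names:

```python
from collections import Counter


def global_split(labels, ratio):
    """Cut the combined list once, as the single-class script does."""
    cut = int(len(labels) * ratio)
    return labels[:cut], labels[cut:]


def per_class_split(labels, ratio):
    """Cut each class separately, then concatenate the pieces."""
    train, evaluation = [], []
    for cls in dict.fromkeys(labels):  # class names in first-seen order
        group = [label for label in labels if label == cls]
        cut = int(len(group) * ratio)
        train += group[:cut]
        evaluation += group[cut:]
    return train, evaluation


# Toy data: one label per annotated object, grouped by class as it might be on disk.
labels = ['stop'] * 8 + ['bicycles'] * 4
for name, split in (('global', global_split), ('per-class', per_class_split)):
    train, evaluation = split(labels, 0.67)
    print(name, dict(Counter(train)), dict(Counter(evaluation)))
# global {'stop': 8} {'bicycles': 4}
# per-class {'stop': 5, 'bicycles': 2} {'stop': 3, 'bicycles': 2}
```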

    With a small change to the script, each class is split 9:1 into a training set and a test set. The source is:

    # coding: utf-8
    import os
    import glob
    import pandas as pd
    import xml.etree.ElementTree as ET
     
    # Replace with the class label names of your own dataset
    classes = ["20Km_h", "no_passing_35", "no_passing", "keep_left", "keep_right", "mandatory", "straight_or_left", "passing_limits",
               "bicycles", "pedestrians", "stop", "dangerous"]
     
    def xml_to_csv(path):
        # Parse every annotation file once and group the rows by class
        rows_by_class = {cls: [] for cls in classes}
        for xml_file in glob.glob(os.path.join(path, '*.xml')):
            tree = ET.parse(xml_file)
            root = tree.getroot()
            for member in root.findall('object'):
                cls = member[0].text
                if cls not in rows_by_class:
                    continue  # objects whose label is not in `classes` are skipped
                value = (root.find('filename').text,
                         int(root.find('size')[0].text),
                         int(root.find('size')[1].text),
                         cls,
                         int(member[4][0].text),
                         int(member[4][1].text),
                         int(member[4][2].text),
                         int(member[4][3].text)
                         )
                rows_by_class[cls].append(value)
     
        # Split each class 9:1 into training rows and evaluation rows
        train_list = []
        eval_list = []
        for cls in classes:
            xml_list = rows_by_class[cls]
            split = int(len(xml_list) * 0.9)
            train_list.extend(xml_list[:split])
            eval_list.extend(xml_list[split:])  # starts exactly at the cut, so no row is lost
     
        column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
     
        # Save as CSV (the data/ folder is created if it does not exist)
        os.makedirs('data', exist_ok=True)
        train_df = pd.DataFrame(train_list, columns=column_name)
        eval_df = pd.DataFrame(eval_list, columns=column_name)
        train_df.to_csv('data/train.csv', index=None)
        eval_df.to_csv('data/eval.csv', index=None)
     
     
    def main():
        # path = 'E:\data\Images'
        path = r'D:\work\PycharmPro\trafficsign\SSD_NET\data\xml_data'  # change to the folder that holds your XML files
        xml_to_csv(path)
        print('Successfully converted xml to csv.')
     
     
    if __name__ == '__main__':
        main()
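    The script keeps the first 90% of each class in the order `glob` returns the files. The evaluation rows are therefore always the last files on disk. This variation is not part of the original article: for a random but reproducible split, you can shuffle each class with a fixed seed before cutting. The sketch below assumes rows in the same column order as above (class at index 3). `split_per_class` and the sample rows are hypothetical:

```python
import random
from collections import defaultdict


def split_per_class(rows, ratio=0.9, seed=0):
    """Group rows by class (index 3), shuffle each group with a fixed
    seed, and cut it at `ratio`.

    Rows follow the column order used above:
    (filename, width, height, class, xmin, ymin, xmax, ymax).
    """
    by_class = defaultdict(list)
    for row in rows:
        by_class[row[3]].append(row)

    rng = random.Random(seed)  # fixed seed -> the same split on every run
    train, evaluation = [], []
    for cls in sorted(by_class):  # sorted so the shuffle order is deterministic
        group = list(by_class[cls])
        rng.shuffle(group)
        cut = int(len(group) * ratio)
        train.extend(group[:cut])
        evaluation.extend(group[cut:])
    return train, evaluation


if __name__ == '__main__':
    # Hypothetical rows for illustration: 20 'stop' objects and 10 'bicycles' objects.
    rows = [(f'img_{i:03d}.jpg', 640, 480, 'stop' if i % 3 else 'bicycles', 0, 0, 10, 10)
            for i in range(30)]
    train, evaluation = split_per_class(rows)
    print(len(train), len(evaluation))
```

    The split is made per object, not per image. An image that contains several objects can still contribute boxes to both sets. If that matters for your evaluation, group the rows by `filename` before splitting.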
    

      

    classes = ["20Km_h", "no_passing_35", "no_passing", "keep_left", "keep_right", "mandatory", "straight_or_left", "passing_limits", "bicycles", "pedestrians", "stop", "dangerous"]

    Change this list to the class label names of your own dataset.
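    Objects whose label is not in `classes` never reach either CSV, so a typo in the list silently drops data. One way to build the list is to collect every `<name>` under `<object>` from the annotation folder. A minimal sketch, assuming one `.xml` file per image as in the scripts above; `collect_classes` is a hypothetical helper name:

```python
import glob
import os
import xml.etree.ElementTree as ET


def collect_classes(path):
    """Return the sorted set of <object><name> values found in path/*.xml."""
    names = set()
    for xml_file in glob.glob(os.path.join(path, '*.xml')):
        root = ET.parse(xml_file).getroot()
        for obj in root.findall('object'):
            names.add(obj.find('name').text)
    return sorted(names)


if __name__ == '__main__':
    # Prints [] if the folder does not exist or holds no .xml files.
    print(collect_classes('./xml'))
```

    Comparing its output with the hand-written `classes` list is a quick check before running the conversion.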


    Original article: https://blog.csdn.net/miao0967020148/article/details/90208139

  • Source: https://www.cnblogs.com/qbdj/p/11024547.html