  • Training YOLOv5 on a custom dataset

    Training YOLOv5 on custom data

    Step 1: References and code

    1. Blog: https://blog.csdn.net/weixin_41868104/article/details/107339535
    2. GitHub code: https://github.com/DataXujing/YOLO-v5
    3. Official code: https://github.com/ultralytics/yolov5
    4. Official tutorial: https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data

    Step 2: Prepare the dataset

    • YOLOv5 expects its labels as txt files, one txt file per image.
    • The code from reference 1 is used to convert the XML annotations into txt label files.
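    In the YOLO txt format each line describes one object as `class x_center y_center width height`, where the class id is zero-based and the four box values are normalized to 0-1 by the image width and height. The conversion from a VOC-style XML annotation can be sketched as below. This is a minimal sketch, not reference 1's code; the function names and the `classes` list are hypothetical.

```python
import xml.etree.ElementTree as ET


def voc_box_to_yolo(xmin, ymin, xmax, ymax, img_w, img_h):
    """Convert VOC corner coordinates (pixels) to normalized YOLO center/size."""
    x_center = (xmin + xmax) / 2.0 / img_w
    y_center = (ymin + ymax) / 2.0 / img_h
    width = (xmax - xmin) / img_w
    height = (ymax - ymin) / img_h
    return x_center, y_center, width, height


def voc_xml_to_yolo_lines(xml_text, classes):
    """Return YOLO label lines for one VOC XML annotation string.

    `classes` is a hypothetical list of class names; a name's index is its class id.
    Objects whose name is not in `classes` are skipped.
    """
    root = ET.fromstring(xml_text)
    size = root.find("size")
    img_w = float(size.find("width").text)
    img_h = float(size.find("height").text)
    lines = []
    for obj in root.iter("object"):
        name = obj.find("name").text
        if name not in classes:
            continue
        box = obj.find("bndbox")
        coords = [float(box.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")]
        x, y, w, h = voc_box_to_yolo(*coords, img_w, img_h)
        lines.append(f"{classes.index(name)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}")
    return lines


if __name__ == "__main__":
    sample = """<annotation>
  <size><width>200</width><height>100</height><depth>3</depth></size>
  <object><name>dog</name>
    <bndbox><xmin>50</xmin><ymin>25</ymin><xmax>150</xmax><ymax>75</ymax></bndbox>
  </object>
</annotation>"""
    # prints ['1 0.500000 0.500000 0.500000 0.500000']
    print(voc_xml_to_yolo_lines(sample, ["cat", "dog"]))
```

    Each returned list would be written, one line per object, to a txt file with the same stem as the image.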

    +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

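    • Update: 2021/2/6 0:42
    • I found a better GitHub repository for converting the dataset; it works for both YOLOv3 and YOLOv5 training.
    • GitHub: https://github.com/pprp/voc2007_for_yolo_torch
    • If your images are not .jpg files, you need to modify the code in tools/make_for_yolov3_torch.py.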

    +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
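    One generic way to avoid a hard-coded `.jpg` suffix is to look each image up by its file stem across several extensions. The helper below is a hypothetical sketch of that idea, not the actual code in tools/make_for_yolov3_torch.py; `IMAGE_EXTS` is an assumed list to adjust for your data.

```python
import os

# Assumed set of image suffixes to try, in order; extend as needed.
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp")


def find_image(image_dir, stem, exts=IMAGE_EXTS):
    """Return the path of the first existing file named stem + ext, or None."""
    for ext in exts:
        # Try both lower- and upper-case suffixes (e.g. .jpg and .JPG)
        for candidate_ext in (ext, ext.upper()):
            path = os.path.join(image_dir, stem + candidate_ext)
            if os.path.isfile(path):
                return path
    return None
```

    Wherever the conversion script builds an image path from an annotation name, a lookup like this could replace the fixed `stem + '.jpg'`.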

    Training

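    • Following reference 2, copy the txt label files from the labels folder and the images from the images folder produced by reference 1 into the reference 2 project.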
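    YOLOv5 finds each image's label file by replacing the last `/images/` in the image path with `/labels/` and the extension with `.txt`, and it reads the dataset split and class names from a yaml file with `train`, `val`, `nc` and `names` keys. The sketch below mirrors both conventions; the function names and example paths are hypothetical, and the yaml is built as plain text to avoid a PyYAML dependency.

```python
import os


def image_to_label_path(img_path):
    """Map .../images/<sub>/x.jpg to .../labels/<sub>/x.txt (last /images/ only)."""
    sa = os.sep + "images" + os.sep
    sb = os.sep + "labels" + os.sep
    head, sep, tail = img_path.rpartition(sa)
    if not sep:
        raise ValueError(f"no {sa!r} component in {img_path!r}")
    stem = os.path.splitext(tail)[0]
    return head + sb + stem + ".txt"


def make_data_yaml(train_dir, val_dir, names):
    """Build the text of a dataset yaml with train/val/nc/names keys."""
    quoted = ", ".join(f"'{n}'" for n in names)
    return (
        f"train: {train_dir}\n"
        f"val: {val_dir}\n"
        f"nc: {len(names)}\n"
        f"names: [{quoted}]\n"
    )
```

    The yaml text would be saved to a file (for example a hypothetical mydata.yaml) and passed to YOLOv5's `train.py` through its `--data` option.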

    Appendix:

    Scripts for splitting off the training and test sets

    1. Sample images: 抽取.py
    import os
    import random
    import shutil
     
    # Copy every file in source_dir into target_dir (subfolders are skipped)
    def cover_files(source_dir, target_dir):
        for file in os.listdir(source_dir):
            source_file = os.path.join(source_dir, file)

            if os.path.isfile(source_file):
                shutil.copy(source_file, target_dir)
     
    def ensure_dir_exists(dir_name):
        """Makes sure the folder exists on disk.
      Args:
        dir_name: Path string to the folder we want to create.
      """
        if not os.path.exists(dir_name):
            os.makedirs(dir_name)
     
     
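    def moveFile(file_dir, save_dir, rate=0.1):
        """Randomly move a `rate` fraction of the files in file_dir to save_dir."""
        ensure_dir_exists(save_dir)
        # Only sample regular files, not subfolders
        path_dir = [f for f in os.listdir(file_dir)
                    if os.path.isfile(os.path.join(file_dir, f))]
        filenumber = len(path_dir)
        # rate: fraction of images to sample, e.g. 0.1 takes 10 out of 100
        picknumber = int(filenumber * rate)  # number of images to take at this rate
        sample = random.sample(path_dir, picknumber)  # randomly choose picknumber images
        for name in sample:
            # os.path.join inserts the separator, so the folder paths need no trailing backslash
            shutil.move(os.path.join(file_dir, name), os.path.join(save_dir, name))


    # On Windows, write paths as raw strings (r'...') so that sequences such as \t or \0
    # are not treated as escape characters
    if __name__ == '__main__':
        file_dir = r'G:\ECANet-master\train\0'  # source image folder
        save_dir = r'G:\ECANet-master\train\00'  # destination folder for the sampled images
        moveFile(file_dir, save_dir)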
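    The script above moves only the sampled images. For YOLO training the matching txt label usually has to move with each image, otherwise the split-off images lose their labels. A minimal sketch, assuming a hypothetical layout where images and labels live in separate folders and each label shares its image's file stem:

```python
import os
import random
import shutil


def split_with_labels(img_dir, label_dir, out_img_dir, out_label_dir, rate=0.1, seed=0):
    """Move a random `rate` fraction of images, plus matching .txt labels, to the output folders.

    Returns the list of moved image names. An image without a label is still moved.
    """
    os.makedirs(out_img_dir, exist_ok=True)
    os.makedirs(out_label_dir, exist_ok=True)
    images = sorted(f for f in os.listdir(img_dir)
                    if os.path.isfile(os.path.join(img_dir, f)))
    rng = random.Random(seed)  # a fixed seed makes the split reproducible
    picked = rng.sample(images, int(len(images) * rate))
    for name in picked:
        shutil.move(os.path.join(img_dir, name), os.path.join(out_img_dir, name))
        label = os.path.splitext(name)[0] + ".txt"
        src_label = os.path.join(label_dir, label)
        if os.path.isfile(src_label):
            shutil.move(src_label, os.path.join(out_label_dir, label))
    return picked
```

    Sorting the file list before sampling keeps the result independent of the order `os.listdir` happens to return.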
    

    json2xml (convert JSON annotations to XML)

    • Put the following three files into a json_to_xml folder:
    1. create_xml_anno.py
    # -*- coding: utf-8 -*-
    from xml.dom.minidom import Document
     
     
    class CreateAnno:
        def __init__(self):
            self.doc = Document()  # create the DOM document object
            self.anno = self.doc.createElement('annotation')  # create the root element
            self.doc.appendChild(self.anno)
     
            self.add_folder()
            self.add_path()
            self.add_source()
            self.add_segmented()
     
            # self.add_filename()
            # self.add_pic_size(width_text_str=str(width), height_text_str=str(height), depth_text_str=str(depth))
     
        def add_folder(self, folder_text_str='JPEGImages'):
            folder = self.doc.createElement('folder')  # create the <folder> element
            folder_text = self.doc.createTextNode(folder_text_str)  # create its text node
            folder.appendChild(folder_text)  # attach the text to the element
            self.anno.appendChild(folder)
     
        def add_filename(self, filename_text_str='00000.jpg'):
            filename = self.doc.createElement('filename')
            filename_text = self.doc.createTextNode(filename_text_str)
            filename.appendChild(filename_text)
            self.anno.appendChild(filename)
     
        def add_path(self, path_text_str="None"):
            path = self.doc.createElement('path')
            path_text = self.doc.createTextNode(path_text_str)
            path.appendChild(path_text)
            self.anno.appendChild(path)
     
        def add_source(self, database_text_str="Unknown"):
            source = self.doc.createElement('source')
            database = self.doc.createElement('database')
            database_text = self.doc.createTextNode(database_text_str)  # write the element text
            database.appendChild(database_text)
            source.appendChild(database)
            self.anno.appendChild(source)
     
        def add_pic_size(self, width_text_str="0", height_text_str="0", depth_text_str="3"):
            size = self.doc.createElement('size')
            width = self.doc.createElement('width')
            width_text = self.doc.createTextNode(width_text_str)  # write the element text
            width.appendChild(width_text)
            size.appendChild(width)
     
            height = self.doc.createElement('height')
            height_text = self.doc.createTextNode(height_text_str)
            height.appendChild(height_text)
            size.appendChild(height)
     
            depth = self.doc.createElement('depth')
            depth_text = self.doc.createTextNode(depth_text_str)
            depth.appendChild(depth_text)
            size.appendChild(depth)
     
            self.anno.appendChild(size)
     
        def add_segmented(self, segmented_text_str="0"):
            segmented = self.doc.createElement('segmented')
            segmented_text = self.doc.createTextNode(segmented_text_str)
            segmented.appendChild(segmented_text)
            self.anno.appendChild(segmented)
     
        def add_object(self,
                       name_text_str="None",
                       xmin_text_str="0",
                       ymin_text_str="0",
                       xmax_text_str="0",
                       ymax_text_str="0",
                       pose_text_str="Unspecified",
                       truncated_text_str="0",
                       difficult_text_str="0"):
            object = self.doc.createElement('object')
            name = self.doc.createElement('name')
            name_text = self.doc.createTextNode(name_text_str)
            name.appendChild(name_text)
            object.appendChild(name)
     
            pose = self.doc.createElement('pose')
            pose_text = self.doc.createTextNode(pose_text_str)
            pose.appendChild(pose_text)
            object.appendChild(pose)
     
            truncated = self.doc.createElement('truncated')
            truncated_text = self.doc.createTextNode(truncated_text_str)
            truncated.appendChild(truncated_text)
            object.appendChild(truncated)
     
            difficult = self.doc.createElement('difficult')
            difficult_text = self.doc.createTextNode(difficult_text_str)
            difficult.appendChild(difficult_text)
            object.appendChild(difficult)
     
            bndbox = self.doc.createElement('bndbox')
            xmin = self.doc.createElement('xmin')
            xmin_text = self.doc.createTextNode(xmin_text_str)
            xmin.appendChild(xmin_text)
            bndbox.appendChild(xmin)
     
            ymin = self.doc.createElement('ymin')
            ymin_text = self.doc.createTextNode(ymin_text_str)
            ymin.appendChild(ymin_text)
            bndbox.appendChild(ymin)
     
            xmax = self.doc.createElement('xmax')
            xmax_text = self.doc.createTextNode(xmax_text_str)
            xmax.appendChild(xmax_text)
            bndbox.appendChild(xmax)
     
            ymax = self.doc.createElement('ymax')
            ymax_text = self.doc.createTextNode(ymax_text_str)
            ymax.appendChild(ymax_text)
            bndbox.appendChild(ymax)
            object.appendChild(bndbox)
     
            self.anno.appendChild(object)
     
        def get_anno(self):
            return self.anno
     
        def get_doc(self):
            return self.doc
     
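        def save_doc(self, save_path):
            # The file encoding must match the encoding='utf-8' declared in the XML header
            with open(save_path, "w", encoding="utf-8") as f:
                self.doc.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')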
    
    2. main.py
    import os
    from tqdm import tqdm
     
    from read_json import ReadAnno
    from create_xml_anno import CreateAnno
     
     
    def json_transform_xml(json_path, xml_path, process_mode="polygon"):
        """Convert one labelme-style JSON file into a VOC-style XML file."""
        json_anno = ReadAnno(json_path, process_mode=process_mode)
        width, height = json_anno.get_width_height()
        filename = json_anno.get_filename()
        coordis = json_anno.get_coordis()
     
        xml_anno = CreateAnno()
        xml_anno.add_filename(filename)
        xml_anno.add_pic_size(width_text_str=str(width), height_text_str=str(height), depth_text_str=str(3))
        for xmin,ymin,xmax,ymax,label in coordis:
            xml_anno.add_object(name_text_str=str(label),
                                xmin_text_str=str(int(xmin)),
                                ymin_text_str=str(int(ymin)),
                                xmax_text_str=str(int(xmax)),
                                ymax_text_str=str(int(ymax)))
     
        xml_anno.save_doc(xml_path)
     
    if __name__ == "__main__":
        root_json_dir = r"/home/aibc/ouyang/temp_dataset/jjson"
        root_save_xml_dir = r"/home/aibc/ouyang/temp_dataset/xml"
        os.makedirs(root_save_xml_dir, exist_ok=True)  # create the output folder if it is missing
        for json_filename in tqdm(os.listdir(root_json_dir)):
            if not json_filename.endswith(".json"):
                continue  # skip non-JSON files in the folder
            json_path = os.path.join(root_json_dir, json_filename)
            save_xml_path = os.path.join(root_save_xml_dir, os.path.splitext(json_filename)[0] + ".xml")
            json_transform_xml(json_path, save_xml_path, process_mode="polygon")
    
    3. read_json.py
    # -*- coding: utf-8 -*-
    import numpy as np
    import json
     
     
    class ReadAnno:
        def __init__(self, json_path, process_mode="rectangle"):
            with open(json_path, encoding="utf-8") as f:
                self.json_data = json.load(f)
            self.filename = self.json_data['imagePath']
            self.width = self.json_data['imageWidth']
            self.height = self.json_data['imageHeight']

            self.coordis = []
            assert process_mode in ["rectangle", "polygon"]
            if process_mode == "rectangle":
                self.process_rectangle_shapes()
            elif process_mode == "polygon":
                self.process_polygon_shapes()
     
        def process_rectangle_shapes(self):
            for single_shape in self.json_data['shapes']:
                bbox_class = single_shape['label']
                (x1, y1), (x2, y2) = single_shape['points'][:2]
                # The two corners may be stored in either order, so sort them
                xmin, xmax = min(x1, x2), max(x1, x2)
                ymin, ymax = min(y1, y2), max(y1, y2)
                self.coordis.append([xmin, ymin, xmax, ymax, bbox_class])
     
        def process_polygon_shapes(self):
            for single_shape in self.json_data['shapes']:
                bbox_class = single_shape['label']
                temp_points = []
                for couple_point in single_shape['points']:
                    x = float(couple_point[0])
                    y = float(couple_point[1])
                    temp_points.append([x,y])
                temp_points = np.array(temp_points)
                xmin, ymin = temp_points.min(axis=0)
                xmax, ymax = temp_points.max(axis=0)
                self.coordis.append([xmin,ymin,xmax,ymax,bbox_class])
     
        def get_width_height(self):
            return self.width, self.height
     
        def get_filename(self):
            return self.filename
     
        def get_coordis(self):
            return self.coordis
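    ReadAnno reads only a few fields from each JSON file: imagePath, imageWidth, imageHeight, and a shapes list whose entries carry a label and a list of points. The pure-Python sketch below mirrors the polygon branch on a hand-written record (the values are made up for illustration), which is a quick way to check the min/max reduction to a bounding box without numpy. `shapes_to_boxes` is a hypothetical name.

```python
import json

# Hypothetical minimal record containing only the fields ReadAnno reads.
SAMPLE = json.dumps({
    "imagePath": "0001.jpg",
    "imageWidth": 640,
    "imageHeight": 480,
    "shapes": [
        {"label": "dog", "points": [[10, 20], [110, 25], [90, 140], [15, 120]]},
    ],
})


def shapes_to_boxes(json_text):
    """Reduce every shape's points to [xmin, ymin, xmax, ymax, label]."""
    data = json.loads(json_text)
    boxes = []
    for shape in data["shapes"]:
        xs = [float(p[0]) for p in shape["points"]]
        ys = [float(p[1]) for p in shape["points"]]
        boxes.append([min(xs), min(ys), max(xs), max(ys), shape["label"]])
    return boxes


if __name__ == "__main__":
    print(shapes_to_boxes(SAMPLE))  # prints [[10.0, 20.0, 110.0, 140.0, 'dog']]
```

    Because it takes the minimum and maximum over all points, the same reduction also handles a two-point rectangle whose corners were stored in reverse order.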
    
  • Original post: https://www.cnblogs.com/zranguai/p/14378722.html