  • Detectron: training and running inference on your own dataset (with notes on the yaml config file)

    1. Create the directories used in the following steps

    2. Convert the xml files in train_annotations into a COCO-style json file

    # coding=utf-8
    import xml.etree.ElementTree as ET
    import os
    import json
    import collections
    
    coco = dict()
    coco['images'] = []
    coco['type'] = 'instances'
    coco['annotations'] = []
    coco['categories'] = []
    
    #category_set = dict()
    image_set = set()
    image_id = 2019100001  # train:2018xxx; val:2019xxx; test:2020xxx
    category_item_id = 1
    annotation_id = 1
    category_set = ['people', 'bicycle', 'electric bicycle']  # list the detection categories here
    '''
    def addCatItem(name):
        global category_item_id
        category_item = dict()
        category_item['supercategory'] = 'none'
        category_item_id += 1
        category_item['id'] = category_item_id
        category_item['name'] = name
        coco['categories'].append(category_item)
        category_set[name] = category_item_id
        return category_item_id
    '''
    
    
    def addCatItem(name):
        '''
        Add an entry to the 'categories' section of the COCO json.
        '''
        global category_item_id
        category_item = collections.OrderedDict()
        category_item['supercategory'] = 'none'
        category_item['id'] = category_item_id
        category_item['name'] = name
        coco['categories'].append(category_item)
        category_item_id += 1
    
    
    def addImgItem(file_name, size):
        global image_id
        if file_name is None:
            raise Exception('Could not find filename tag in xml file.')
        if size['width'] is None:
            raise Exception('Could not find width tag in xml file.')
        if size['height'] is None:
            raise Exception('Could not find height tag in xml file.')
        # image_item = dict()    # use collections.OrderedDict() so the keys keep a fixed order
        image_item = collections.OrderedDict()
        jpg_name = os.path.splitext(file_name)[0] + '.jpg'
        image_item['file_name'] = jpg_name
        image_item['width'] = size['width']
        image_item['height'] = size['height']
        image_item['id'] = image_id
        coco['images'].append(image_item)
        image_set.add(jpg_name)
        image_id = image_id + 1
        return image_id  # already incremented here, hence the 'current_image_id - 1' used below
    
    
    def addAnnoItem(object_name, image_id, category_id, bbox):
        global annotation_id
        #annotation_item = dict()
        annotation_item = collections.OrderedDict()
        annotation_item['segmentation'] = []
        seg = []
        # bbox[] is x,y,w,h
        # left_top
        seg.append(bbox[0])
        seg.append(bbox[1])
        # left_bottom
        seg.append(bbox[0])
        seg.append(bbox[1] + bbox[3])
        # right_bottom
        seg.append(bbox[0] + bbox[2])
        seg.append(bbox[1] + bbox[3])
        # right_top
        seg.append(bbox[0] + bbox[2])
        seg.append(bbox[1])
        annotation_item['segmentation'].append(seg)
        annotation_item['area'] = bbox[2] * bbox[3]
        annotation_item['iscrowd'] = 0
        annotation_item['image_id'] = image_id
        annotation_item['bbox'] = bbox
        annotation_item['category_id'] = category_id
        annotation_item['id'] = annotation_id
        annotation_item['ignore'] = 0
        annotation_id += 1
        coco['annotations'].append(annotation_item)
    
    
    def parseXmlFiles(xml_path):
        xmllist = os.listdir(xml_path)
        xmllist.sort()
        for f in xmllist:
            if not f.endswith('.xml'):
                continue
    
            bndbox = dict()
            size = dict()
            current_image_id = None
            current_category_id = None
            file_name = None
            size['width'] = None
            size['height'] = None
            size['depth'] = None
    
            xml_file = os.path.join(xml_path, f)
            print(xml_file)
    
            tree = ET.parse(xml_file)
            root = tree.getroot()  # get the root element
    
            if root.tag != 'annotation':  # the root tag must be 'annotation'
                raise Exception(
                    'pascal voc xml root element should be annotation, rather than {}'.format(root.tag))
    
            # elem is <folder>, <filename>, <size>, <object>
            for elem in root:
                current_parent = elem.tag
                current_sub = None
                object_name = None
    
                # elem.tag, elem.attrib,elem.text
                if elem.tag == 'folder':
                    continue
    
                if elem.tag == 'filename':
                    file_name = elem.text
                    if file_name in category_set:
                        raise Exception('file_name duplicated')
    
                # add img item only after parse <size> tag
                elif current_image_id is None and file_name is not None and size['width'] is not None:
                    if file_name not in image_set:
                    current_image_id = addImgItem(file_name, size)  # record the image info
                        print('add image with {} and {}'.format(file_name, size))
                    else:
                        raise Exception('duplicated image: {}'.format(file_name))
                # subelem is <width>, <height>, <depth>, <name>, <bndbox>
                for subelem in elem:
                    bndbox['xmin'] = None
                    bndbox['xmax'] = None
                    bndbox['ymin'] = None
                    bndbox['ymax'] = None
    
                    current_sub = subelem.tag
                    if current_parent == 'object' and subelem.tag == 'name':
                        object_name = subelem.text
                        # if object_name not in category_set:
                        #    current_category_id = addCatItem(object_name)
                        # else:
                        #current_category_id = category_set[object_name]
                        # list index starts at 0 but category ids in the json start at 1, hence +1
                        current_category_id = category_set.index(object_name) + 1
                    elif current_parent == 'size':
                        if size[subelem.tag] is not None:
                            raise Exception('xml structure broken at size tag.')
                        size[subelem.tag] = int(subelem.text)
    
                    # option is <xmin>, <ymin>, <xmax>, <ymax>, when subelem is <bndbox>
                    for option in subelem:
                        if current_sub == 'bndbox':
                            if bndbox[option.tag] is not None:
                                raise Exception(
                                    'xml structure corrupted at bndbox tag.')
                            bndbox[option.tag] = int(option.text)
    
                    # only after parse the <object> tag
                    if bndbox['xmin'] is not None:
                        if object_name is None:
                            raise Exception('xml structure broken at bndbox tag')
                        if current_image_id is None:
                            raise Exception('xml structure broken at bndbox tag')
                        if current_category_id is None:
                            raise Exception('xml structure broken at bndbox tag')
                        bbox = []
                        # x
                        bbox.append(bndbox['xmin'])
                        # y
                        bbox.append(bndbox['ymin'])
                        # w
                        bbox.append(bndbox['xmax'] - bndbox['xmin'])
                        # h
                        bbox.append(bndbox['ymax'] - bndbox['ymin'])
                        print(
                            'add annotation with {},{},{},{}'.format(object_name, current_image_id - 1, current_category_id, bbox))
                        addAnnoItem(object_name, current_image_id -
                                    1, current_category_id, bbox)
        # finally, build the categories section
        for categoryname in category_set:
            addCatItem(categoryname)
    
    
    if __name__ == '__main__':
        xml_path = 'dataset/train_anatations'  # directory of VOC-style xml annotations
        json_file = 'VOC2007/anatations/voc_2007_train.json'  # output COCO-style json
        # xml_path = 'dataset/test_anatation'
        # json_file = 'dataset/test.json'
        parseXmlFiles(xml_path)
        json.dump(coco, open(json_file, 'w'))
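
    To sanity-check the generated file, a short snippet like the one below can be used; it only assumes the json output path from the __main__ block above.

    # Quick sanity check of the generated COCO-style json.
    import json

    with open('VOC2007/anatations/voc_2007_train.json') as fp:
        data = json.load(fp)

    print('images:     ', len(data['images']))
    print('annotations:', len(data['annotations']))
    print('categories: ', [c['name'] for c in data['categories']])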

    3. Create the txt file (image list)

    # -*- coding: utf-8 -*-
    # @Author: zhiwei
    # @Date:   2019-01-31 09:38:58
    # @Last Modified by:   zhiwei
    # @Last Modified time: 2019-01-31 10:21:58
    #
    import os
    import re
    
    fp1_path = "JPEGImages"
    f = open('VOCdevkit2020/VOC2020/ImageSets/Main/train.txt', 'w')
    s = ""
    i = 0
    for filename in os.listdir(fp1_path):
        # bicycle_train.txt
        # if re.match('^bicycle.+', filename) != None:
        #     s1 = os.path.splitext(os.path.basename(filename))[0]
        #     s = s + s1 + ' ' + str(1) + "\n"
        # else:
        #     s1 = os.path.splitext(os.path.basename(filename))[0]
        #     s = s + s1 + ' ' + str(0) + "\n"
    
        # electric bicycle_train.txt
        # if "electric_bicycle" in filename:
        #     s1 = os.path.splitext(os.path.basename(filename))[0]
        #     s = s + s1 + ' ' + str(1) + "\n"
        # else:
        #     s1 = os.path.splitext(os.path.basename(filename))[0]
        #     s = s + s1 + ' ' + str(0) + "\n"
    
        # train.txt: one image basename per line
        s1 = os.path.splitext(os.path.basename(filename))[0]
        s = s + s1 + "\n"
    
    print(s)
    
    f.write(s)
    f.close()
    print(len([name for name in os.listdir(fp1_path)]))
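
    As a quick check that every image made it into the list, the number of entries in train.txt should equal the number of files in JPEGImages; a small sketch using the same paths as the script above:

    # Verify that train.txt has one entry per image in JPEGImages.
    import os

    with open('VOCdevkit2020/VOC2020/ImageSets/Main/train.txt') as fp:
        names = [line.strip() for line in fp if line.strip()]

    print(len(names), 'entries in train.txt')
    print(len(os.listdir('JPEGImages')), 'files in JPEGImages')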

    4. Set up the yaml file

    Under the Detectron/configs/12_2017_baselines directory, copy the file retinanet_R-50-FPN_1x.yaml into the Detectron/configs/my directory and rename it retinanet_R-50-FPN_1x1.0.yaml.

    5. Modify the yaml file (retinanet_R-50-FPN_1x1.0.yaml)

    MODEL:
      TYPE: retinanet
      CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
      NUM_CLASSES: 4
    NUM_GPUS: 1
    SOLVER:
      WEIGHT_DECAY: 0.0001
      LR_POLICY: steps_with_decay
      BASE_LR: 0.001
      GAMMA: 0.1
      MAX_ITER: 1000
      STEPS: [0, 600, 800]
    FPN:
      FPN_ON: True
      MULTILEVEL_RPN: True
      RPN_MAX_LEVEL: 7
      RPN_MIN_LEVEL: 3
      COARSEST_STRIDE: 128
      EXTRA_CONV_LEVELS: True
    RETINANET:
      RETINANET_ON: True
      NUM_CONVS: 4
      ASPECT_RATIOS: (1.0, 2.0, 0.5)
      SCALES_PER_OCTAVE: 3
      ANCHOR_SCALE: 4
      LOSS_GAMMA: 2.0
      LOSS_ALPHA: 0.25
    TRAIN:
      WEIGHTS: /home/Desktop/test/trainMOdel/R-50.pkl
      DATASETS: ('voc_2007_train',)
      SCALES: (800,)
      MAX_SIZE: 1333
      RPN_STRADDLE_THRESH: -1  # default 0
    TEST:
      DATASETS: ('coco_2014_minival',)
      SCALE: 800
      MAX_SIZE: 1333
      NMS: 0.5
      RPN_PRE_NMS_TOP_N: 10000  # Per FPN level
      RPN_POST_NMS_TOP_N: 2000
    OUTPUT_DIR: .
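
    Before going further it is worth confirming that the edited file still parses as valid YAML, since indentation mistakes are easy to introduce. A minimal check, assuming PyYAML is installed and that the file sits at configs/my/retinanet_R-50-FPN_1x1.0.yaml as created in step 4:

    # Minimal YAML syntax check for the edited config (path from step 4 assumed).
    import yaml

    with open('configs/my/retinanet_R-50-FPN_1x1.0.yaml') as fp:
        cfg = yaml.safe_load(fp)

    print(cfg['MODEL']['NUM_CLASSES'])  # expect 4 for the 3-class dataset here
    print(cfg['SOLVER']['BASE_LR'])     # expect 0.001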

    Here we only explain the meaning of a few specific parameters.

    1) cfg is the configuration file; all of them live under the configs directory.
    In MODEL:

    MODEL:
        TYPE: generalized_rcnn
        CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
        NUM_CLASSES: 81
        FASTER_RCNN: True

    Beginners should pay attention to NUM_CLASSES: for a custom dataset its value is the number of classes + 1 (background included), so for COCO it is 80 + 1 = 81 (a small arithmetic check follows below).
    For a Mask network, the MODEL section must also include:

    MASK_ON: True
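
    Returning to NUM_CLASSES: for the dataset in this post the arithmetic is simply the three categories from step 2 plus one background class, which is why the full config above sets NUM_CLASSES: 4.

    # NUM_CLASSES = detection classes + 1 for background; matches the yaml above.
    category_set = ['people', 'bicycle', 'electric bicycle']
    NUM_CLASSES = len(category_set) + 1
    print(NUM_CLASSES)  # 4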

    Set the number of GPUs:

    NUM_GPUS: 1

    SOLVER settings:

    SOLVER:
        WEIGHT_DECAY: 0.0001
        LR_POLICY: steps_with_decay
        BASE_LR: 0.0025
        GAMMA: 0.1
        MAX_ITER: 60000
        STEPS: [0, 30000, 40000]

    First, regarding the number of training iterations: if the dataset is small, a few thousand iterations may be enough, while for something as large as COCO tens of thousands are worthwhile. The default for a single GPU is 60000 iterations. Note that with multiple GPUs, MAX_ITER scales down in inverse proportion to the number of GPUs. The other parameter worth mentioning is BASE_LR: the initial learning rate matters a great deal, since too large a value keeps the network from converging to a minimum and too small a value makes convergence very slow. As noted in an earlier post, values between 1e-3 and 1e-4 are usually safe. The 0.0025 used here is presumably the value the authors settled on after repeated experiments with this network and data. With multiple GPUs, BASE_LR scales up in proportion to the number of GPUs.
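
    As a concrete illustration of the two scaling rules just mentioned, here is a small sketch (not part of Detectron; the 1-GPU baseline values are taken from the config above): BASE_LR grows linearly with the number of GPUs while MAX_ITER and STEPS shrink by the same factor.

    # Sketch of the scaling rules above: with k GPUs, BASE_LR is multiplied by k
    # while the iteration-based schedule (MAX_ITER, STEPS) is divided by k.
    # Baseline values are the 1-GPU ones from the yaml config in this post.
    def scale_schedule(base_lr, max_iter, steps, num_gpus):
        k = num_gpus  # the 1-GPU config is taken as the reference
        return {
            'BASE_LR': base_lr * k,
            'MAX_ITER': max_iter // k,
            'STEPS': [s // k for s in steps],
        }

    print(scale_schedule(0.001, 1000, [0, 600, 800], num_gpus=1))
    print(scale_schedule(0.001, 1000, [0, 600, 800], num_gpus=2))
    # -> {'BASE_LR': 0.002, 'MAX_ITER': 500, 'STEPS': [0, 300, 400]}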

    FPN and FAST_RCNN settings:

    FPN:
        FPN_ON: True
        MULTILEVEL_ROIS: True
        MULTILEVEL_RPN: True
    FAST_RCNN:
        ROI_BOX_HEAD: fast_rcnn_heads.add_roi_2mlp_head
        ROI_XFORM_METHOD: RoIAlign
        ROI_XFORM_RESOLUTION: 7
        ROI_XFORM_SAMPLING_RATIO: 2

    I once tried changing the FPN settings, but the model would have to change along with them; normally neither of these two sections needs to be touched for a custom dataset.

    TRAIN settings:

    TRAIN:
        WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/MSRA/R-50.pkl
        DATASETS: ('coco_2014_train',)
        SCALES: (500,)
        MAX_SIZE: 833
        BATCH_SIZE_PER_IM: 256
        RPN_PRE_NMS_TOP_N: 2000  # Per FPN level

    Note DATASETS here: your own dataset has to be registered in dataset_catalog.py, and the expected dataset format is also documented in detail on GitHub. A sketch of such a catalog entry follows below.
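
    For reference, a hedged sketch of what such a catalog entry amounts to: detectron/datasets/dataset_catalog.py keeps a DATASETS dictionary whose values point at an image directory and a COCO-style annotation json. The constant names (IM_DIR, ANN_FN), the _DATA_DIR root, and the paths below mirror the stock entries but should be checked against your own Detectron version and directory layout; stock Detectron may already contain a voc_2007_train entry, in which case it is enough to make its paths point at your data.

    # Hedged sketch of a dataset_catalog.py entry; names and paths are assumptions.
    _DATA_DIR = 'detectron/datasets/data'   # Detectron's default data root (assumed)
    IM_DIR = 'image_directory'
    ANN_FN = 'annotation_file'

    DATASETS = {
        'voc_2007_train': {
            IM_DIR: _DATA_DIR + '/VOC2007/JPEGImages',
            ANN_FN: _DATA_DIR + '/VOC2007/annotations/voc_2007_train.json',
        },
    }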

    TEST settings:

    TEST:
        DATASETS: ('coco_2014_minival',)
        SCALE: 500
        MAX_SIZE: 833
        NMS: 0.5
        RPN_PRE_NMS_TOP_N: 1000  # Per FPN level
        RPN_POST_NMS_TOP_N: 1000

    The parameters to watch here are SCALE and MAX_SIZE: a common way to raise accuracy at inference time is to run at a higher resolution, and these are the two values to change for that.
    The NMS threshold can also cut down duplicate boxes in non-crowded scenes; a short sketch of what it controls follows below.
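
    For intuition, below is a minimal greedy NMS sketch in NumPy; it is illustrative only, not Detectron's own implementation, but it shows what the NMS: 0.5 threshold controls: boxes whose IoU with an already-kept, higher-scoring box exceeds the threshold are suppressed.

    # Minimal greedy NMS sketch (illustrative, not Detectron's op).
    # Boxes are [x1, y1, x2, y2]; higher-scoring boxes are kept first and any
    # remaining box overlapping a kept one by more than iou_thresh is dropped.
    import numpy as np

    def nms(boxes, scores, iou_thresh=0.5):
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort()[::-1]  # indices sorted by descending score
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            # IoU of the current best box with all remaining boxes
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
            iou = inter / (areas[i] + areas[order[1:]] - inter)
            order = order[1:][iou <= iou_thresh]  # drop overlapping boxes
        return keep

    boxes = np.array([[10, 10, 100, 100], [12, 12, 98, 98], [200, 200, 260, 260]], dtype=float)
    scores = np.array([0.9, 0.8, 0.7])
    print(nms(boxes, scores))  # -> [0, 2]: the second box overlaps the first too much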

    6. Start training

    cd into the $Detectron/tools/ directory and run:

    python train_net.py --cfg ../configs/my2020/retinanet_R-50-FPN_1x.0.yaml OUTPUT_DIR /home/gaomh/Desktop/test/trainMOdel
    • --cfg: path to the configuration file
    • OUTPUT_DIR: output directory for the training results

    After that the training just runs on its own.

    7. Problems encountered

    INFO loader.py: 126: Stopping enqueue thread
    INFO loader.py: 113: Stopping mini-batch loading thread
    INFO loader.py: 113: Stopping mini-batch loading thread
    INFO loader.py: 113: Stopping mini-batch loading thread
    INFO loader.py: 113: Stopping mini-batch loading thread
    Traceback (most recent call last):
      File "tools/train_net.py", line 132, in <module>
        main()
      File "tools/train_net.py", line 114, in main
        checkpoints = detectron.utils.train.train_model()
      File "/home/learner/github/detectron/detectron/utils/train.py", line 86, in train_model
        handle_critical_error(model, 'Loss is NaN')

    Solution: change BASE_LR from 0.01 to 0.001; the loss turned NaN because the learning rate was too high.

  • Original post: https://www.cnblogs.com/answerThe/p/12120913.html