  • mmdetection2: loss becomes NaN

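    It had been a while since I last used mmdetection. Today I trained a cascade-rcnn-r101 model on the VisDrone dataset, and the loss kept becoming NaN. I first suspected that the learning rate was too high, so I tried 0.02*batchsize/16, 0.000001 and 0, but the loss still became NaN. Since even a learning rate of 0 did not help, the optimizer was not the cause, so I suspected the data: there might be invalid ground-truth boxes, i.e. boxes whose top-left corner is not strictly above and to the left of the bottom-right corner (xmin >= xmax or ymin >= ymax). The check I used: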

    import os
    import xml.etree.ElementTree as ET

    # Directory holding the VOC-style XML annotation files to check
    xml_root = "./data"


    def check_bbox(xml_root):
        """Print every box with xmin >= xmax or ymin >= ymax; return how many were found."""
        n_bad = 0
        for xml_name in sorted(os.listdir(xml_root)):
            if not xml_name.endswith(".xml"):
                continue  # skip images and any other non-annotation files
            root = ET.parse(os.path.join(xml_root, xml_name)).getroot()
            for obj in root.findall("object"):
                bnd_box = obj.find("bndbox")
                # Truncate to int, as the training loader sees integer pixel coordinates
                bbox = [
                    int(float(bnd_box.find(tag).text))
                    for tag in ("xmin", "ymin", "xmax", "ymax")
                ]
                # A valid box needs xmin < xmax and ymin < ymax
                if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]:
                    print("invalid bbox (xmin >= xmax or ymin >= ymax):", bbox, xml_name)
                    n_bad += 1
        return n_bad


    check_bbox(xml_root)
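Why would such a box give NaN even with a learning rate of 0? One plausible mechanism, assuming the standard R-CNN delta encoding that mmdet's `DeltaXYWHBBoxCoder` is based on: the width target is log(gt_w / proposal_w), so a zero-width ground truth gives log(0) and a negative width gives the log of a negative number. In PyTorch these become -inf and nan and flow into the regression loss; because the targets do not depend on the learning rate, lowering it cannot help. The hypothetical `bbox2delta` below is a simplified pure-Python sketch (no mean/std normalisation), where `math.log` raises `ValueError` instead of silently returning -inf/nan.

```python
import math


def bbox2delta(proposal, gt):
    """Simplified (dx, dy, dw, dh) regression target for one proposal/gt pair.

    Boxes are (x1, y1, x2, y2); width = x2 - x1, height = y2 - y1.
    Real detectors additionally normalise the deltas with means/stds.
    """
    px, py = (proposal[0] + proposal[2]) / 2, (proposal[1] + proposal[3]) / 2
    pw, ph = proposal[2] - proposal[0], proposal[3] - proposal[1]
    gx, gy = (gt[0] + gt[2]) / 2, (gt[1] + gt[3]) / 2
    gw, gh = gt[2] - gt[0], gt[3] - gt[1]
    dx, dy = (gx - px) / pw, (gy - py) / ph
    # log of a zero or negative size is undefined: math.log raises ValueError here,
    # whereas torch.log would return -inf (size 0) or nan (size < 0) and keep going.
    dw, dh = math.log(gw / pw), math.log(gh / ph)
    return dx, dy, dw, dh


print(bbox2delta((10, 10, 50, 50), (12, 8, 52, 48)))  # finite deltas
try:
    bbox2delta((10, 10, 50, 50), (30, 20, 30, 40))  # gt with xmin == xmax
except ValueError as err:
    print("degenerate gt box:", err)
```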
    

    Below is a record of the full process of training on VisDrone with mmdetection2.

    • Preprocess VisDrone: convert its txt labels to VOC format
    import os
    from PIL import Image

    root_dir = "/mnt/A/pengyuan/data/Visd2019/trainval/"
    annotations_dir = root_dir + "annotations/"  # VisDrone txt labels
    image_dir = root_dir + "images/"
    xml_dir = root_dir + "Annotations/"  # output VOC XML files
    # Index = VisDrone object_category; 0 (ignored regions) and 11 (others) are skipped below
    class_name = ['ignored regions', 'pedestrian', 'people', 'bicycle', 'car', 'van',
                  'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor', 'others']

    os.makedirs(xml_dir, exist_ok=True)

    for filename in os.listdir(annotations_dir):
        if not filename.endswith('.txt'):
            continue
        image_name = os.path.splitext(filename)[0]
        with Image.open(image_dir + image_name + ".jpg") as img:
            width, height = img.size
        xml_name = xml_dir + image_name + '.xml'
        with open(annotations_dir + filename, 'r') as fin, open(xml_name, 'w') as fout:
            fout.write('<annotation>\n')
            fout.write('\t<folder>VOC2007</folder>\n')
            fout.write('\t<filename>' + image_name + '.jpg</filename>\n')
            fout.write('\t<source>\n')
            fout.write('\t\t<database>VisDrone2019 Database</database>\n')
            fout.write('\t\t<annotation>VisDrone2019</annotation>\n')
            fout.write('\t\t<image>flickr</image>\n')
            fout.write('\t\t<flickrid>Unspecified</flickrid>\n')
            fout.write('\t</source>\n')
            fout.write('\t<owner>\n')
            fout.write('\t\t<flickrid>Haipeng Zhang</flickrid>\n')
            fout.write('\t\t<name>Haipeng Zhang</name>\n')
            fout.write('\t</owner>\n')
            fout.write('\t<size>\n')
            fout.write('\t\t<width>' + str(width) + '</width>\n')
            fout.write('\t\t<height>' + str(height) + '</height>\n')
            fout.write('\t\t<depth>3</depth>\n')
            fout.write('\t</size>\n')
            fout.write('\t<segmented>0</segmented>\n')

            for line in fin:
                # <bbox_left>,<bbox_top>,<bbox_width>,<bbox_height>,<score>,
                # <object_category>,<truncation>,<occlusion>
                fields = line.strip().split(',')
                if len(fields) < 8:
                    continue  # skip blank or malformed lines
                x, y, w, h = (int(v) for v in fields[:4])
                category = int(fields[5])
                if category == 0 or category == 11:
                    continue  # ignored regions / others
                # pay attention to this point! (0-based, inclusive xmax/ymax)
                # With the -1 below, a 1-pixel-wide (or tall) box would get xmax == xmin
                # (or ymax == ymin): exactly the invalid box check_bbox() reports.
                if w <= 1 or h <= 1:
                    continue
                fout.write('\t<object>\n')
                fout.write('\t\t<name>' + class_name[category] + '</name>\n')
                fout.write('\t\t<pose>Unspecified</pose>\n')
                fout.write('\t\t<truncated>' + fields[6] + '</truncated>\n')
                # VisDrone occlusion (0/1/2) is written as <difficult>; VOC-style loaders
                # typically treat a non-zero <difficult> as an ignored box.
                fout.write('\t\t<difficult>' + str(int(fields[7])) + '</difficult>\n')
                fout.write('\t\t<bndbox>\n')
                fout.write('\t\t\t<xmin>' + str(x) + '</xmin>\n')
                fout.write('\t\t\t<ymin>' + str(y) + '</ymin>\n')
                fout.write('\t\t\t<xmax>' + str(x + w - 1) + '</xmax>\n')
                fout.write('\t\t\t<ymax>' + str(y + h - 1) + '</ymax>\n')
                fout.write('\t\t</bndbox>\n')
                fout.write('\t</object>\n')

            fout.write('</annotation>')
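Besides the XML files, a VOC-format dataset needs an image-set list. Assuming mmdet 2.x's `XMLDataset` layout, `ann_file` is a text file of image ids, annotations are read from `<img_prefix>/Annotations/<id>.xml` and images from `<img_prefix>/JPEGImages/<id>.jpg`; `VOCDataset` also infers the year from `VOC2007`/`VOC2012` in `img_prefix`. So VisDrone's `images/` folder would need to be renamed or symlinked to `JPEGImages/` under a `VOC2007/` directory. The hypothetical helper `write_image_set` below builds the id list from the converted XML files.

```python
import os


def write_image_set(xml_dir, out_file):
    """Write the sorted image ids (XML file stems) to a VOC ImageSets list file."""
    ids = sorted(
        os.path.splitext(name)[0]
        for name in os.listdir(xml_dir)
        if name.endswith(".xml")
    )
    # Create ImageSets/Main/ (or whatever parent directory) if needed
    os.makedirs(os.path.dirname(out_file) or ".", exist_ok=True)
    with open(out_file, "w") as f:
        f.write("\n".join(ids) + "\n")
    return ids
```

For the paths used above, that would be `write_image_set(root_dir + "Annotations/", root_dir + "ImageSets/Main/trainval.txt")`.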
    
    • Set up mmdetection2 to train on your own dataset

      • Edit the model config
        I chose the config ./configs/cascade_rcnn/cascade_rcnn_r101_fpn_1x_coco.py, which inherits from ./configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py.
        Since this run uses VOC-format data, change the dataset entry of cascade_rcnn_r50_fpn_1x_coco.py to ../_base_/datasets/voc0712.py.

      • Edit the dataset config
        The dataset config is configs/_base_/datasets/voc0712.py. Change the input size, the dataset paths and the batch size there, as shown in the figure below:
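A sketch of the parts of voc0712.py that control those three settings, assuming mmdet 2.x config keys. `data_root` is a hypothetical path, and the image scale and batch size are the stock values rather than values tuned for VisDrone. The `VOC2007/` sub-directory matters because, in mmdet 2.x, `VOCDataset` infers the year from `img_prefix` and raises an error if it cannot.

```python
# Sketch of the edited parts of configs/_base_/datasets/voc0712.py (mmdet 2.x keys).
dataset_type = 'VOCDataset'
data_root = '/path/to/VOCdevkit/'  # hypothetical; must contain VOC2007/
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),  # input size
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
data = dict(
    samples_per_gpu=2,  # batch size per GPU
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        times=3,
        dataset=dict(
            type=dataset_type,
            ann_file=data_root + 'VOC2007/ImageSets/Main/trainval.txt',
            img_prefix=data_root + 'VOC2007/',
            pipeline=train_pipeline)))
# val/test entries follow the same pattern with a test pipeline (omitted here)
```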

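      • Change the learning rate
        The learning rate is set in configs/_base_/schedules/schedule_1x.py. For single-GPU training the recommended value is 0.02*batchsize/16 (the default 0.02 assumes a total batch of 16 images, i.e. 8 GPUs × 2 images), as shown in the figure below: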
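The linear scaling rule behind 0.02*batchsize/16, as a small helper (`scaled_lr` is a hypothetical name). "Batch size" here means the total number of images per iteration, i.e. number of GPUs × `samples_per_gpu`; the optimizer line mirrors the one in mmdet 2.x's schedule_1x.py with only `lr` changed.

```python
def scaled_lr(batch_size, base_lr=0.02, base_batch=16):
    """Linear scaling rule: lr is proportional to the total batch size."""
    return base_lr * batch_size / base_batch


# One GPU with samples_per_gpu=2 -> total batch size 2
lr = scaled_lr(2)
# The result goes into the optimizer line of schedule_1x.py
optimizer = dict(type='SGD', lr=lr, momentum=0.9, weight_decay=0.0001)
print(optimizer)
```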

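      • Change the number of classes
        For cascade_rcnn_r101_fpn_1x_coco.py, the number of classes is set in configs/_base_/models/cascade_rcnn_r50_fpn.py. Set num_classes to the number of classes you need; Cascade R-CNN has one bbox head per stage (three in total), so num_classes must be changed in all three. Unlike older versions, mmdetection2 does not need an extra +1 for the background class, so here it is 10.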
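Missing one of the three heads is easy, so here is a hypothetical helper that sets every `num_classes` key in a nested config and reports how many it changed (3 is expected for Cascade R-CNN). It works on plain dicts and lists; the toy config in the test only mimics the shape of the real one.

```python
def set_num_classes(cfg, n):
    """Recursively set every 'num_classes' key in a nested dict/list config.

    Returns the number of entries changed, so you can confirm that all
    bbox heads (three for Cascade R-CNN) were updated.
    """
    changed = 0
    if isinstance(cfg, dict):
        for key, value in cfg.items():
            if key == 'num_classes':
                cfg[key] = n  # replacing a value does not resize the dict
                changed += 1
            else:
                changed += set_num_classes(value, n)
    elif isinstance(cfg, (list, tuple)):
        for item in cfg:
            changed += set_num_classes(item, n)
    return changed
```

Assuming mmcv 1.x's `Config` API, one way to apply it is `cfg = Config.fromfile(...)`, then `set_num_classes(cfg.model, 10)` and `cfg.dump(...)` to write a new config (`ConfigDict` is a `dict` subclass); editing the three entries by hand in cascade_rcnn_r50_fpn.py achieves the same result.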

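      • Change the class names
        Class names must be changed in two places. The first is class VOCDataset in mmdet/datasets/voc.py, as shown in the figure below:

        The second is where mAP is computed: the class names in mmdet/core/evaluation/class_names.py.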
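The names in both places must match the `<name>` strings written by the conversion script, in label order. In mmdet 2.x the XML loader skips objects whose name is not in `CLASSES`, so a typo would silently drop boxes rather than raise an error. The sketch below derives the tuple from the converter's `class_name` list; `VISDRONE_CLASSES` is a hypothetical name, and in voc.py you would assign the same tuple to `CLASSES`.

```python
# Same list as in the conversion script; index = VisDrone object_category
class_name = ['ignored regions', 'pedestrian', 'people', 'bicycle', 'car', 'van',
              'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor', 'others']

# Categories 0 and 11 are skipped during conversion, leaving ids 1..10
VISDRONE_CLASSES = tuple(class_name[1:11])


def voc_classes():
    """Replacement body for voc_classes() in mmdet/core/evaluation/class_names.py."""
    return list(VISDRONE_CLASSES)


num_classes = 10  # value set in cascade_rcnn_r50_fpn.py
assert len(VISDRONE_CLASSES) == num_classes
print(VISDRONE_CLASSES)
```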

    • Training
      For convenience, I wrote a small script, mytrain.sh, for training:

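    #!/bin/bash
    # This is the r50 config edited above; to train the cascade-rcnn-r101 model, pass
    # configs/cascade_rcnn/cascade_rcnn_r101_fpn_1x_coco.py instead (it inherits those edits).
    CUDA_VISIBLE_DEVICES=6 python tools/train.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py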
    

    Then run sh mytrain.sh.

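    • Testing
      Likewise, a small script mytest.sh for testing:
    #!/bin/bash
    # export, so the setting reaches the python processes started below
    export CUDA_VISIBLE_DEVICES=6
    # mmdet 2.x's VOCDataset evaluates 'mAP' (or 'recall'); 'bbox' is the COCO-style metric
    python tools/test.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_12.pth --out results.pkl --eval mAP --show
    python tools/voc_eval.py results.pkl ./configs/my_data.py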
    
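    • Sometimes training still fails after the configs have been edited, with the error: of MMDataParallel does not matches the length of CLASSES 20) in RepeatDataset. CLASSES 20 means the dataset still uses the 20 default VOC class names, typically because the installed mmdet package does not reflect the edited source, so it most likely needs to be rebuilt.

    Run

    python setup.py develop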
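Rebuilding helps when Python imports mmdet from a copy in site-packages rather than from the repository you edited; `python setup.py develop` installs the repository in development (editable) mode, so edits to mmdet/datasets/voc.py take effect. The hypothetical helpers below, using only the standard library, report which file a module would be imported from and whether it lies inside a given repository directory.

```python
import importlib.util
import os


def module_origin(name):
    """Return the file a top-level module would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


def imported_from_repo(name, repo_root):
    """True if `name` resolves to a file inside repo_root (e.g. an editable install)."""
    origin = module_origin(name)
    if origin is None:
        return False
    repo_root = os.path.realpath(repo_root)
    return os.path.realpath(origin).startswith(repo_root + os.sep)


if __name__ == "__main__":
    print("mmdet would be imported from:", module_origin("mmdet"))
```

If the printed path is not inside your mmdetection checkout, the edited class names are not the ones being used.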
    
  • Original article: https://www.cnblogs.com/peng-yuan/p/14662919.html