zoukankan      html  css  js  c++  java
  • gluoncv 目标检测,训练自己的数据集

    https://gluon-cv.mxnet.io/build/examples_datasets/detection_custom.html

    官方提供两种方案,一种是lst文件,一种是xml文件(voc的格式);

    voc 格式的标注有标注工具,但是你如果是json文件标注的信息,或者其他格式的,你就要转成voc格式的。

    于是就选择第一种数据格式lst序列文件格式,格式很简单。

    根据你自己的json或者其他格式文件转换一下。

    import json
    import os
    import cv2
    import numpy as np
    
    
    def write_line(img_path, im_shape, boxes, ids, idx):
        h, w, c = im_shape
        # for header, we use minimal length 2, plus width and height
        # with A: 4, B: 5, C: width, D: height
        A = 4
        B = 5
        C = w
        D = h
        # concat id and bboxes
        labels = np.hstack((ids.reshape(-1, 1), boxes)).astype('float')
        # normalized bboxes (recommanded)
        labels[:, (1, 3)] /= float(w)
        labels[:, (2, 4)] /= float(h)
        # flatten
        labels = labels.flatten().tolist()
        str_idx = [str(idx)]
        str_header = [str(x) for x in [A, B, C, D]]
        str_labels = [str(x) for x in labels]
        str_path = [img_path]
        line = '	'.join(str_idx + str_header + str_labels + str_path) + '
    '
        return line
    
    
    files = os.listdir('train_front')
    json_url = []
    cnt = 0
    for file in files:
        tmp = os.listdir('train_front/'+file)
        for js in tmp:
            if js.endswith('json'):
                json_url.append('train_front/'+file+'/'+js)
                cnt+=1
    print(cnt)
    
    fwtrain = open("train.lst","w")
    fwval = open("val.lst","w")
    
    first_flag = []
    flag = True
    
    cnt = 0
    cnt1 = 0
    cnt2 = 0
    for json_url_index in json_url:
        file = open(json_url_index,'r')
        for line in file:
            js = json.loads(line)
    
            if 'person' in js:
                boxes = []
                ids = []
                for i in range(len(js['person'])):
                    if js['person'][i]['attrs']['ignore'] == 'yes' or js['person'][i]['attrs']['occlusion']== 'heavily_occluded' or js['person'][i]['attrs']['occlusion']== 'invisible':
                        continue
    
    
                    bbox = js['person'][i]['data']
                    url = '/mnt/hdfs-data-4/data/jian.yin/'+json_url_index[:-5]+'/'+js['image_key']
                    width = js['width']
                    height = js['height']
                    boxes.append(bbox)
                    ids.append(0)
    
                    print(url)
                    print(bbox)
    
                if len(boxes) > 0:
                    if flag:
                        flag = False
                        first_flag = boxes
                    ids = np.array(ids)
    
                    if cnt < 27853//2:
    
                        line = write_line(url,(height,width,3),boxes,ids,cnt1)
                        fwtrain.write(line)
                        cnt1+=1
    
                    if cnt >= 27853//2:
                        line = write_line(url, (height, width, 3), boxes, ids, cnt2)
                        fwval.write(line)
                        cnt2+=1
    
                    cnt += 1
    
    fwtrain.close()
    fwval.close()
    print(first_flag)

    lst文件就转换好了。

    然后添加自己的数据集:

    https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L73

    这里不能直接套用前面的导入数据的过程。

    按照教程给出的方式添加。投机取巧的验证方式,直接引用前面的。

    或者不验证:https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L393 部分注释掉。

        elif dataset.lower() == 'pedestrian':
            lst_dataset = LstDetection('train_val.lst',root=os.path.expanduser('.'))
            print(len(lst_dataset))
            first_img = lst_dataset[0][0]
    
            print(first_img.shape)
            print(lst_dataset[0][1])
            
            train_dataset = LstDetection('train.lst',root=os.path.expanduser('.'))
            val_dataset = LstDetection('val.lst',root=os.path.expanduser('.'))
            classs = ('pedestrian',)
            val_metric = VOC07MApMetric(iou_thresh=0.5,class_names=classs)

    训练参数:

    https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L73

    添加自己的训练参数或者直接套用。

        if args.dataset == 'voc' or args.dataset == 'pedestrian':
            args.epochs = int(args.epochs) if args.epochs else 20
            args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '14,20'
            args.lr = float(args.lr) if args.lr else 0.001
            args.lr_warmup = args.lr_warmup if args.lr_warmup else -1
            args.wd = float(args.wd) if args.wd else 5e-4

    model_zoo.py添加自己的数据集映射方案。这里如果是pip install gluoncv ,就要到site-package里面改。

    https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_zoo.py#L32

    'faster_rcnn_resnet50_v1b_pedestrian': faster_rcnn_resnet50_v1b_voc,

  • 相关阅读:
    zabbix添加对haproxy的监控
    【转】最近搞Hadoop集群迁移踩的坑杂记
    【转】Hive配置文件中配置项的含义详解(收藏版)
    【转】Spark-Sql版本升级对应的新特性汇总
    kylin查询出现日期对应不上的情况
    【转】saiku与kylin整合备忘录
    Eclipse中Ctrl+方法名发现无法进入到该方法中……
    maven会报Could not transfer artifact xxx错误
    【转】CDH5.x升级
    【转】Kylin实践之使用Hive视图
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10174899.html
Copyright © 2011-2022 走看看