zoukankan      html  css  js  c++  java
  • 目标检测算法SSD之训练自己的数据集

    目标检测算法SSD之训练自己的数据集

    prerequesties 预备知识/前提条件

    下载和配置了最新SSD代码

    git clone https://github.com/weiliu89/caffe ~/work/ssd
    cd $_
    git checkout ssd
    

    编译caffe

    下载必要的模型(包括prototxt和caffemodel);

    运行了evaluation和webcam的例子,会提示caffe的import报错。添加pycaffe路径到PYTHONPATH环境变量,或者写一个_init_paths.py来辅助引入都可以(推荐后者)。

    准备自己的数据集

    做成VOC2007格式的:

    JPEGImages/*.png
    ImageSets/Main/*.txt
    Annotations/*.xml
    

    这3个目录

    生成训练用的lmdb数据

    我这里数据集名叫traffic_sign,放在/home/chris/data/traffic_sign

    1.复制原有脚本文件

    cd ~/work/ssd
    cp -R data/VOC0712 data/traffic_sign
    

    2.修改data/traffic_sign/create_list.sh

    #!/bin/bash
    
    #root_dir=$HOME/data/VOCdevkit/
    root_dir=$HOME/data/
    sub_dir=ImageSets/Main
    bash_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" #当前文件所在目录
    for dataset in train test
    do
      dst_file=$bash_dir/$dataset.txt
      if [ -f $dst_file ]
      then
        rm -f $dst_file
      fi
      for name in traffic_sign
      do
        if [[ $dataset == "test" && $name == "VOC2012" ]]
        then
          continue
        fi
        echo "Create list for $name $dataset..."
        dataset_file=$root_dir/$name/$sub_dir/$dataset.txt
    
        img_file=$bash_dir/$dataset"_img.txt"
        cp $dataset_file $img_file
        sed -i "s/^/$name/JPEGImages//g" $img_file   #在行首插入目录名
        sed -i "s/$/.png/g" $img_file     #在行尾追加.png后缀
    
        label_file=$bash_dir/$dataset"_label.txt"
        cp $dataset_file $label_file
        sed -i "s/^/$name/Annotations//g" $label_file  #在行首插入目录名
        sed -i "s/$/.xml/g" $label_file   #在行尾追加.xml后缀
    
        paste -d' ' $img_file $label_file >> $dst_file  #img_file和label文件的对应行拼接
    
        rm -f $label_file
        rm -f $img_file
      done
    
      # Generate image name and size infomation.
      if [ $dataset == "test" ]
      then
        $bash_dir/../../build/tools/get_image_size $root_dir $dst_file $bash_dir/$dataset"_name_size.txt"
      fi
    
      # Shuffle train file.
      if [ $dataset == "train" ]
      then
        rand_file=$dst_file.random
        cat $dst_file | perl -MList::Util=shuffle -e 'print shuffle(<STDIN>);' > $rand_file
        mv $rand_file $dst_file
      fi
    done
    

    3.修改data/traffic_sign/create_data.sh

    #!/bin/bash
    
    cur_dir=$(cd $( dirname ${BASH_SOURCE[0]} ) && pwd )
    root_dir=$cur_dir/../..
    
    cd $root_dir
    
    redo=1
    data_root_dir="$HOME/data"
    #dataset_name="VOC0712"
    dataset_name="traffic_sign"
    mapfile="$root_dir/data/$dataset_name/labelmap_voc.prototxt"
    anno_type="detection"
    db="lmdb"
    min_dim=0
    max_dim=0
    width=0
    height=0
    
    extra_cmd="--encode-type=png --encoded"
    if [ $redo ]
    then
      extra_cmd="$extra_cmd --redo"
    fi
    for subset in test train
    do
      python $root_dir/scripts/create_annoset.py --anno-type=$anno_type --label-map-file=$mapfile --min-dim=$min_dim --max-dim=$max_dim --resize-width=$width --resize-height=$height --check-label $extra_cmd $data_root_dir $root_dir/data/$dataset_name/$subset.txt $data_root_dir/$dataset_name/$db/$dataset_name"_"$subset"_"$db examples/$dataset_name
    done
    

    4.修改data/traffic_sign/labelmap_voc.prototxt

    item {
      name: "none_of_the_above"
      label: 0
      display_name: "background"
    }
    item {
      name: "sign"
      label: 1
      display_name: "sign"
    }
    

    5.生成数据

    # 确保你还是在ssd代码根目录,比如我是~/work/ssd
    ./data/traffic_sign/create_list.sh
    ./data/traffic_sign/create_data.sh
    

    执行训练

    依然需要修改ssd默认的训练脚本内容,来匹配自己的数据集。

    1.复制原有训练脚本

    cd ~/work/ssd
    cd examples/ssd
    cp ssd_pascal.py ssd_traffic.py
    

    2.修改训练脚本

    编辑ssd_traffic.py内容,修改:

    • 数据集指向
      train_datatest_data , 指向examples中你的数据,例如:
    train_data = "examples/traffic_sign/traffic_sign_train_lmdb"
    test_data = "examples/traffic_sign/trainffic_sign_test_lmdb"
    

    这里很奇怪,我的examples/traffic_sign/目录下确实有这两个lmdb的文件夹,是指向~/data/traffic_sign/lmdb/目录下的两个lmdb文件夹,但是训练时提示lmdb错误。

    换成链接文件的源文件,也就是写绝对路径,就不报错了。

    • 测试图像数量

    num_test_image 该变量修改成自己数据集中测试数据图片的数量

    • 类别数

    num_classes 该变量修改成自己数据集中 标签类别数量数 + 1

    • gpu选项

    gpus = "0,1,2,3" 电脑有几个gpu就写多少个,如果有一个就写gpus="0",两个就写gpus="0,1",以此类推

    • 迭代次数
    solver_param = {
    	...
    	'stepvalue': [50000, 60000, 70000],
    	'max_iter': 70000,
    	'snapshot': 10000,
    }
    
    • 各种VOC0712换成自己数据集的名字(我的是traffic_sign)
    model_name = "VGG_traffic_sign_{}".format(job_name)
    save_dir = "models/VGGNet/traffic_sign/{}".format(job_name)
    snapshot_dir = 
    job_dir = 
    name_size_file = 
    label_map_file = 
    
    • batch_size
      比如6G显存的970显卡,跑不起来SSD。修改:
    batch_size = 16   # 32->16
    accum_batch_size = 16  # 32->16
    

    此时显存占用为4975MiB

    如果你显存很大,与其闲置不如使用它,调大batch_size即可

    • base_lr
      调整了batch_size或单纯因为数据集的原因,导致出现loss为nan的情况,考虑减小学习率,这里通过减小base_lr实现。

    3.执行训练

    cd ~/work/ssd    #务必到ssd的根目录执行
    python examples/ssd/ssd_traffic.py
    
    ## reference
    
    https://my.oschina.net/u/1046919/blog/777470
  • 相关阅读:
    javascript打开本地应用
    SDUT OJ -2892 A
    恳请CSDN的活动可以落实
    中国银联mPOS通用技术安全分析和规范解读
    UNIX环境编程学习——反思认识
    STM32F407VG (五)定时器
    请求的链式处理——责任链模式
    Shredding Company (hdu 1539 dfs)
    十天精通CSS3(6)
    十天精通CSS3(5)
  • 原文地址:https://www.cnblogs.com/zjutzz/p/6845002.html
Copyright © 2011-2022 走看看