zoukankan      html  css  js  c++  java
  • Caffe学习系列(四)之--训练自己的模型

    前言:

        本文章记录了我将自己的数据集处理并训练的流程,帮助一些刚入门的学习者,也记录自己的成长,万事起于忽微,量变引起质变。

    正文:

    一、流程

        1)准备数据集

        2)数据转换为lmdb格式

        3)计算均值并保存(非必需)

        4)创建模型并编写配置文件

        5)训练和测试

    二、实施

    (一)准备数据集

           在深度学习中,数据集准备往往是最难的事情,因为数据涉及隐私、商业等各方面,获取难度很大,不过有很多科研机构公布了供学习使用的数据集,我们可以在网上下载。还有一种获取的途径是论文,查阅国内外相关的论文,看他们是如何获取到数据集的,我们也可以使用他 们所采用的数据集。

    我要训练的模型是人脸识别,训练的数据集是在网上下载的,经过整理,在我的网盘可以下载:http://pan.baidu.com/s/1jIxCcKI

    (二)数据转换为lmdb格式

        生成lmdb格式的文件通过脚本来实现,这就需要我们自己编写脚本文件,这里遇到了一些坑,首先使用vim创建脚本文件create1.sh

    #!/usr/bin/env sh
    DATA=AR1
    MY=newfile
    echo "Create train.txt..."
    rm -rf $MY/train.txt
    for i in 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
    do
    find $DATA/train/$i -name *.pgm|cut -d '/' -f2-4 | sed "s/$/ $i/">>$MY/train.txt
    done
    echo "Create test.txt..."
    rm -rf $MY/test.txt
    for i in 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
    do
    find $DATA/test/$i -name *.pgm|cut -d '/' -f2-4 | sed "s/$/ $i/">>$MY/test.txt
    done
    echo "All done"

    这个脚本文件中,用到了rm,find, cut, sed,cat等linux命令。

    rm: 删除文件

    find: 寻找文件

    cut: 截取路径

    sed: 在每行的最后面加上标注。本例中将找到的*cat.jpg文件加入标注为1,找到的*bike.jpg文件加入标注为2

    cat: 将两个类别合并在一个文件里。

    执行这个脚本:

    sh data/face/create1.sh

       成功的话就会在newfile文件夹里生成train.txt和test.txt文件,比如

        

    f2-4的含义是选取以“/”而分隔开的第2至第4部分

    接着再编写一个脚本文件,调用convert_imageset命令来转换数据格式。

    vim lmdb.sh
    #!/usr/bin/env sh
    MY=data/face/newfile
    echo "Create train lmdb.."
    rm -rf $MY/img_train_lmdb
    build/tools/convert_imageset --shuffle 
    --resize_height=256 
    --resize_width=256 
    /home/zyf/ygh/project/caffe/data/face/AR1/ $MY/train.txt $MY/img_train_lmdb
    echo "Create test lmdb.."
    rm -rf $MY/img_test_lmdb
    build/tools/convert_imageset 
    --shuffle 
    --resize_width=256 
    --resize_height=256 
    /home/zyf/ygh/project/caffe/data/face/AR1/ 
    $MY/test.txt 
    $MY/img_test_lmdb
    echo "All Done.."

    我统一转换成256*256大小。

    sh lmdb.sh

    运行成功后,会在 newfile下面生成两个文件夹img_train_lmdb和img_test_lmdb,分别用于保存图片转换后的lmdb文件。

    (三)计算均值并保存(非必需)

    图片减去均值再训练,会提高训练速度和精度。因此,一般都会有这个操作。

    caffe程序提供了一个计算均值的文件compute_image_mean.cpp,我们直接使用就可以了

    build/tools/compute_image_mean data/face/newfile/img_train_lmdb data/face/newfile/mean.binaryproto 
    
    
    compute_image_mean带两个参数,第一个参数是lmdb训练数据位置,第二个参数设定均值文件的名字及保存路径。运行成功后,会在 newfile/ 下面生成一个mean.binaryproto的均值文件。

    (四)创建模型并编写配置文件

    模型里面的数据

    data_param {
        source: "data/face/newfile/img_train_lmdb"
        backend:LMDB
        batch_size: 128
      }
    
    
    transform_param {
         mean_file: "data/face/newfile/mean.binaryproto"
         mirror: true
      }

    这其中的source和mean_file的路径要改成前面你自己生成的文件目录,其余的不需要修改,我这里采用的是网上训练精度不错的一个网络,具体下载可以转到百度云:  链接

     其中的train_val.prototxt是训练网络

    然后修改其中的solver.prototxt

    net: "data/face/train_val.prototxt"
    test_iter: 10
    test_interval: 100
    
    base_lr: 0.001
    lr_policy: "step"
    gamma: 0.95
    stepsize:  100
    momentum: 0.9
    weight_decay: 0.0005
    
    display: 100
    max_iter:  5000
    snapshot:  5000
    snapshot_prefix: "data/face"
    solver_mode: GPU
    device_id:0
    #debug_info: true
    其中test_iter: 10,test_interval: 100,一千张图片每次测试100张,10次就都可以覆盖了。。在训练过程中,调整学习率,逐步变小。

    (五)训练和测试

     如果前面都没有问题,数据准备好了,配置文件也配置好了,这一步就比较简单了。

    build/tools/caffe train -solver data/face/solver.prototxt

    直接训练即可,可以实时在命令行下查看其精度与loss。

     

    待续。。。

                                                                         by  still

     

  • 相关阅读:
    【arc072f】AtCoder Regular Contest 072 F
    maven settings解决下载不了依赖包问题
    git 命令提交本地代码到新创建的仓库
    JAVA 利用切面、注解 动态判断请求信息中字段是否需要为空
    JAVA 根据身份证号码解析出生日期、性别、年龄
    利用JAVA正则快速获取URL的文件名
    datalist
    Mybatis map接收list参数
    bootstrap-table 列宽动态拖拽改变宽度
    JAVA 枚举类遍历与switch使用
  • 原文地址:https://www.cnblogs.com/ygh1229/p/6724392.html
Copyright © 2011-2022 走看看