zoukankan      html  css  js  c++  java
  • caffe简易上手指南(二)—— 训练我们自己的数据

    训练我们自己的数据

     

    本篇继续之前的教程,下面我们尝试使用别人定义好的网络,来训练我们自己的网络。

    1、准备数据

    首先很重要的一点,我们需要准备若干种不同类型的图片进行分类。这里我选择从ImageNet上下载了3个分类的图片(Cat,Dog,Fish)。

    图片需要分两批:训练集(train)、测试集(test),一般训练集与测试集的比例大概是5:1以上,此外每个分类的图片也不能太少,我这里每个分类大概选了5000张训练图+1000张测试图。

    找好图片以后,需要准备以下文件:

    words.txt:分类序号与分类对应关系(注意:要从0开始标注

    0 cat
    1 dog
    2 fish

    train.txt:标明训练图片路径及其对应分类,路径和分类序号直接用空格分隔,最好随机打乱一下图片

    /opt/caffe/examples/my_simple_image/data/cat_train/n02123045_4416.JPEG 0
    /opt/caffe/examples/my_simple_image/data/cat_train/n02123045_3568.JPEG 0
    /opt/caffe/examples/my_simple_image/data/fish_train/n02512053_4451.JPEG 2
    /opt/caffe/examples/my_simple_image/data/cat_train/n02123045_3179.JPEG 0
    /opt/caffe/examples/my_simple_image/data/cat_train/n02123045_6956.JPEG 0
    /opt/caffe/examples/my_simple_image/data/cat_train/n02123045_10143.JPEG 0
    ......

    val.txt:标明测试图片路径及其对应分类

    /opt/caffe/examples/my_simple_image/data/dog_val/n02084071_12307.JPEG 1
    /opt/caffe/examples/my_simple_image/data/dog_val/n02084071_10619.JPEG 1
    /opt/caffe/examples/my_simple_image/data/cat_val/n02123045_13360.JPEG 0
    /opt/caffe/examples/my_simple_image/data/cat_val/n02123045_13060.JPEG 0
    /opt/caffe/examples/my_simple_image/data/cat_val/n02123045_11859.JPEG 0
    ......

    2、生成lmdb文件

    lmdb是caffe使用的一种输入数据格式,相当于我们把图片及其分类重新整合一下,变成一个数据库输给caffe训练。

    这里我们使用caffenet的create_imagenet.sh文件修改,主要是重新指定一下路径:

    EXAMPLE=examples/my_simple_image/
    DATA=examples/my_simple_image/data/
    TOOLS=build/tools
    
    TRAIN_DATA_ROOT=/
    VAL_DATA_ROOT=/
    
    # 这里我们打开resize,需要把所有图片尺寸统一
    RESIZE=true
    if $RESIZE; then
      RESIZE_HEIGHT=256
      RESIZE_WIDTH=256
    else
      RESIZE_HEIGHT=0
      RESIZE_WIDTH=0
    fi
    
    .......
    
    echo "Creating train lmdb..."
    
    GLOG_logtostderr=1 $TOOLS/convert_imageset 
        --resize_height=$RESIZE_HEIGHT 
        --resize_width=$RESIZE_WIDTH 
        --shuffle 
        $TRAIN_DATA_ROOT 
        $DATA/train.txt 
        $EXAMPLE/ilsvrc12_train_lmdb  #生成的lmdb路径
    
    echo "Creating val lmdb..."
    
    GLOG_logtostderr=1 $TOOLS/convert_imageset 
        --resize_height=$RESIZE_HEIGHT 
        --resize_width=$RESIZE_WIDTH 
        --shuffle 
        $VAL_DATA_ROOT 
        $DATA/val.txt 
        $EXAMPLE/ilsvrc12_val_lmdb    #生成的lmdb路径
    echo "Done."

    3、生成mean_file

    下面我们用lmdb生成mean_file,用于训练(具体做啥用的我还没研究。。。)

    这里也是用imagenet例子的脚本:

    EXAMPLE=examples/my_simple_image
    DATA=examples/my_simple_image
    TOOLS=build/tools
    
    $TOOLS/compute_image_mean $EXAMPLE/ilsvrc12_train_lmdb $DATA/imagenet_mean.binaryproto
    
    echo "Done."

    4、修改solver、train_val配置文件

    这里我们可以选用cifar的网络,也可以用imagenet的网络,不过后者的网络结构更复杂一些,为了学习,我们就用cifar的网络来改。

    把cifar的两个配置文件拷过来:

    cifar10_quick_solver.prototxt
    cifar10_quick_train_test.prototxt

    首先修改cifar10_quick_train_test.prototxt的路径以及输出层数量(标注出黑体的部分):

    name: "CIFAR10_quick"
    layer {
      name: "cifar"
      type: "Data"
      top: "data"
      top: "label"
      include {
        phase: TRAIN
      }
      transform_param {
        mean_file: "examples/my_simple_image/imagenet_mean.binaryproto"
      }
      data_param {
    source: "examples/my_simple_image/ilsvrc12_train_lmdb" batch_size: 50 #一次训练的图片数量,一般指定50也够了 backend: LMDB } } layer { name: "cifar" type: "Data" top: "data" top: "label" include { phase: TEST } transform_param { mean_file: "examples/my_simple_image/imagenet_mean.binaryproto" } data_param { source: "examples/my_simple_image/ilsvrc12_val_lmdb" batch_size: 50 #一次训练的图片数量 backend: LMDB } }
    ..........
    layer { name: "ip2" type: "InnerProduct" bottom: "ip1" top: "ip2" .......... inner_product_param { num_output: 3 #输出层数量,就是你要分类的个数 weight_filler { type: "gaussian" std: 0.1 } bias_filler { type: "constant" } } } ......

    cifar10_quick_solver.prototxt的修改根据自己的实际需要:

    net: "examples/my_simple_image/cifar/cifar10_quick_train_test.prototxt"   #网络文件路径
    test_iter: 20 #测试执行的迭代次数
    test_interval: 10 #迭代多少次进行测试 base_lr: 0.001 #迭代速率,这里我们改小了一个数量级,因为数据比较少
    momentum: 0.9 weight_decay: 0.004 lr_policy: "fixed" #采用固定学习速率的模式display: 1 #迭代几次就显示一下信息,这里我为了及时跟踪效果,改成1 max_iter: 4000 #最大迭代次数 snapshot: 1000 #迭代多少次生成一次快照 snapshot_prefix: "examples/my_simple_image/cifar/cifar10_quick" #快照路径和前缀 solver_mode: CPU #CPU或者GPU

    5、开始训练

    运行下面的命令,开始训练(为了方便可以做成脚本)

    ./build/tools/caffe train --solver=examples/my_simple_image/cifar/cifar10_quick_solver.prototxt

    6、小技巧

    网络的配置和训练其实有一些小技巧。

    - 训练过程中,正确率时高时低是很正常的现象,但是总体上是要下降的

    - 观察loss值的趋势,如果迭代几次以后一直在增大,最后变成nan,那就是发散了,需要考虑减小训练速率,或者是调整其他参数

    - 数据不能太少,如果太少的话很容易发散

     

  • 相关阅读:
    SAP OPEN UI5 Step 8: Translatable Texts
    SAP OPEN UI5 Step7 JSON Model
    SAP OPEN UI5 Step6 Modules
    SAP OPEN UI5 Step5 Controllers
    SAP OPEN UI5 Step4 Xml View
    SAP OPEN UI5 Step3 Controls
    SAP OPEN UI5 Step2 Bootstrap
    SAP OPEN UI5 Step1 环境安装和hello world
    2021php最新composer的使用攻略
    Php使用gzdeflate和ZLIB_ENCODING_DEFLATE结果gzinflate报data error
  • 原文地址:https://www.cnblogs.com/alexcai/p/5469436.html
Copyright © 2011-2022 走看看