zoukankan      html  css  js  c++  java
  • Chapter 3 Start Caffe with MNIST Demo

    先从一个具体的例子来开始Caffe,以MNIST手写数据为例。 

    1.下载数据

    下载mnist到caffe-masterdatamnist文件夹。

    THE MNIST DATABASE:Yann LeCun et al.

     train-images-idx3-ubyte.gz:  training set images (9912422 bytes)

     train-labels-idx1-ubyte.gz:  training set labels (28881 bytes)

     t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes)

     t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)

    2.生成lmdb文件

    使用convert_mnist_data project转换数据。

    打开Caffe.sln,设置convert_mnist_data为启动项目,修改convert_mnist_data.cpp中代码。

    在main函数中设置了转换数据的路径,具体的代码如下:

        //get mnist train and test lmdb data By Xiaopan Lyu=====================
        argc = 4;
        argv[0] = "lmdb";
        //convert train mnist data=============================================    
        argv[1] = "../../data/mnist/train-images.idx3-ubyte";
        argv[2] = "../../data/mnist/train-labels.idx1-ubyte";
        argv[3] = "../../data/mnist/mnist_train_lmdb";
     
        //convert test mnist data=============================================    
        argv[1] = "../../data/mnist/t10k-images.idx3-ubyte";
        argv[2] = "../../data/mnist/t10k-labels.idx1-ubyte";
        argv[3] = "../../data/mnist/mnist_test_lmdb";
        //======================================================================

    这段代码在main函数中的位置如下:

    int main(int argc, char** argv) {
    #ifndef GFLAGS_GFLAGS_H_
        namespace gflags = google;
    #endif
     
        FLAGS_alsologtostderr = 1;
     
        //get mnist train and test lmdb data By Xiaopan Lyu=====================
        argc = 4;
        argv[0] = "lmdb";
        //convert train mnist data=============================================    
        argv[1] = "../../data/mnist/train-images.idx3-ubyte";
        argv[2] = "../../data/mnist/train-labels.idx1-ubyte";
        argv[3] = "../../data/mnist/mnist_train_lmdb";
     
        //convert test mnist data=============================================    
        argv[1] = "../../data/mnist/t10k-images.idx3-ubyte";
        argv[2] = "../../data/mnist/t10k-labels.idx1-ubyte";
        argv[3] = "../../data/mnist/mnist_test_lmdb";
        //======================================================================
     
        gflags::SetUsageMessage("This script converts the MNIST dataset to
    "
            "the lmdb/leveldb format used by Caffe to load data.
    "
            "Usage:
    "
            "    convert_mnist_data [FLAGS] input_image_file input_label_file "
            "output_db_file
    "
            "The MNIST dataset could be downloaded at
    "
            "    http://yann.lecun.com/exdb/mnist/
    "
            "You should gunzip them after downloading,"
            "or directly use data/mnist/get_mnist.sh
    ");
        gflags::ParseCommandLineFlags(&argc, &argv, true);
     
        const string& db_backend = FLAGS_backend;
     
        if (argc != 4) {
            gflags::ShowUsageWithFlagsRestrict(argv[0],
                "examples/mnist/convert_mnist_data");
        }
        else {
            google::InitGoogleLogging(argv[0]);
            convert_dataset(argv[1], argv[2], argv[3], db_backend);
        }
        system("pause");
        return 0;
    }

    两次运行代码,分别得到train和test data。

    get mnist_train_lmdb

    QQ截图20160817163114

    get mnist_test_lmdb

    QQ截图20160817163205

    Notes.

    1)argv[0]、argv[1]、argv[2]、argv[3]分别表示的含义:[FLAGS]     input_image_file    input_label_file    output_db_file.

    2)output_db_file设置中最后的一级文件夹不要事先自己建立好,代码中不支持覆盖,如果存在文件夹会报错。

    3)这些路径的设置是在debug模式下,文件的层级是以当前的.cpp文件为基础的,与实际EXE文件有所不同。

    3.配置网络为TRAIN模式

    1)配置lenet_train_test.prototxt

    caffe在mnist自带的是使用lenet的网络结。lenet网络的定义在examplesmnistlenet_train_test.prototxt文件中。

    注意配置好改网络定义中的数据路径。如下所示,注意line 14&31,如果目录比较混淆,可以直接写成绝对路径。

       1:  name: "LeNet"
       2:  layer {
       3:    name: "mnist"
       4:    type: "Data"
       5:    top: "data"
       6:    top: "label"
       7:    include {
       8:      phase: TRAIN
       9:    }
      10:    transform_param {
      11:      scale: 0.00390625
      12:    }
      13:    data_param {
      14:      source: "E:/MyCode/DL/caffe-master/examples/mnist/mnist_train_lmdb"
      15:      batch_size: 64
      16:      backend: LMDB
      17:    }
      18:  }
      19:  layer {
      20:    name: "mnist"
      21:    type: "Data"
      22:    top: "data"
      23:    top: "label"
      24:    include {
      25:      phase: TEST
      26:    }
      27:    transform_param {
      28:      scale: 0.00390625
      29:    }
      30:    data_param {
      31:      source: "E:/MyCode/DL/caffe-master/examples/mnist/mnist_test_lmdb"
      32:      batch_size: 100
      33:      backend: LMDB
      34:    }
      35:  }

    2)配置lenet_solver.prototxt

    lenet_solver.prototxt中实际上是定义了一种解决方案。

    注意line 2,23&25,这三行的数据需要修改,这里也是用了绝对路径。只使用CPU训练。

       1:  # The train/test net protocol buffer definition
       2:  net: "E:/MyCode/DL/caffe-master/examples/mnist/lenet_train_test.prototxt"
       3:  # test_iter specifies how many forward passes the test should carry out.
       4:  # In the case of MNIST, we have test batch size 100 and 100 test iterations,
       5:  # covering the full 10,000 testing images.
       6:  test_iter: 100
       7:  # Carry out testing every 500 training iterations.
       8:  test_interval: 500
       9:  # The base learning rate, momentum and the weight decay of the network.
      10:  base_lr: 0.01
      11:  momentum: 0.9
      12:  weight_decay: 0.0005
      13:  # The learning rate policy
      14:  lr_policy: "inv"
      15:  gamma: 0.0001
      16:  power: 0.75
      17:  # Display every 100 iterations
      18:  display: 100
      19:  # The maximum number of iterations
      20:  max_iter: 10000
      21:  # snapshot intermediate results
      22:  snapshot: 5000
      23:  snapshot_prefix: "E:/MyCode/DL/caffe-master/examples/mnist/lenet"
      24:  # solver mode: CPU or GPU
      25:  solver_mode: CPU

    3)修改了source code为train模式

    修改了caffe.cpp文件的相关内容。增加了line15到line21的代码,顺便说一句Google的gflags解析命令行参数甚是优雅。

       1:  int main(int argc, char** argv) {
       2:      // Print output to stderr (while still logging).
       3:      FLAGS_alsologtostderr = 1;
       4:      // Set version
       5:      gflags::SetVersionString(AS_STRING(CAFFE_VERSION));
       6:      // Usage message.
       7:      gflags::SetUsageMessage("command line brew
    "
       8:          "usage: caffe <command> <args>
    
    "
       9:          "commands:
    "
      10:          "  train           train or finetune a model
    "
      11:          "  test            score a model
    "
      12:          "  device_query    show GPU diagnostic information
    "
      13:          "  time            benchmark model execution time");
      14:      // Run tool or show usage.
      15:      //train lenet By XiaopanLyu====================================================
      16:      argc = 3;
      17:      argv[0] = "caffe";
      18:      argv[1] = "train";
      19:      argv[2] = "-solver=E:/MyCode/DL/caffe-master/examples/mnist/lenet_solver.prototxt";
      20:      //argv[1] = "solver=../examples/mnist/lenet_solver.prototxt";
      21:      //=============================================================================
      22:      caffe::GlobalInit(&argc, &argv);
      23:      if (argc == 2) {
      24:  #ifdef WITH_PYTHON_LAYER
      25:          try {
      26:  #endif
      27:              return GetBrewFunction(caffe::string(argv[1]))();
      28:              system("pause");
      29:  #ifdef WITH_PYTHON_LAYER
      30:          }
      31:          catch (bp::error_already_set) {
      32:              PyErr_Print();
      33:              return 1;
      34:          }
      35:  #endif
      36:      }
      37:      else {
      38:          gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
      39:      }
      40:  }

    4.Training LeNet

    运行caffe project,mnist demo就开始运行了,以下就是运行的过程和结果。

    1)运行过程

    QQ截图20160817181544

    2)运行结果

    运行完成后,生成了4个文件。查看lenet_solver.prototxt可知,最大迭代次数为10000次,5000次保存一次快照结果。

    QQ截图20160818165202

    5.配置网络为TEST模式

    修改caffe.cpp文件,增加参数配置代码

       1:      //test lenet By XiaopanLyu====================================================
       2:      argc = 5;
       3:      argv[0] = "caffe";
       4:      argv[1] = "test";
       5:      argv[2] = "-model=E:/MyCode/DL/caffe-master/examples/mnist/lenet_train_test.prototxt";
       6:      argv[3] = "-weights=E:/MyCode/DL/caffe-master/examples/mnist/lenet_iter_10000.caffemodel";
       7:      argv[4] = "-iterations=100";
       8:      //=============================================================================
       9:      caffe::GlobalInit(&argc, &argv);

    6.Testing LeNet

    用LeNet的网络配置运行mnist测试数据集,几分钟的时间得到如下效果。

    QQ截图20160818173929

    迭代100次,测试数据集的准确率为99.02%。

    7.NOTES

    在第3、5部分,配置网络的参数可以参考Caffe的官方辅导文档:http://caffe.berkeleyvision.org/tutorial/interfaces.html

  • 相关阅读:
    JS 页面截屏,转为图片
    php js 交互(js调用PHP代码执行)
    微信开发,自定义菜单不生效怎么办?重新关注也无效
    ios 带scrollView的控制器,双击“状态栏”,返回scrollView的顶部
    iOS 文件共享 ,通过手机助手/mac 访问APP沙盒
    cell 各自的高度不同的时候
    释放控制器。控制器的生命周期,有 定时器的 控制器
    TmpCode
    ios uploader 上传IPA到itunes
    UIImageView的image的图片显示 imageView.contentMode
  • 原文地址:https://www.cnblogs.com/xiaopanlyu/p/5780538.html
Copyright © 2011-2022 走看看