zoukankan      html  css  js  c++  java
  • CaffeExample 在CIFAR-10数据集上训练与测试

    本文主要来自Caffe作者Yangqing Jia网站给出的examples

    @article{jia2014caffe,
      Author = {Jia, Yangqing and Shelhamer, Evan and Donahue, Jeff and Karayev, Sergey and Long, Jonathan and Girshick, Ross and Guadarrama, Sergio and Darrell, Trevor},
      Journal = {arXiv preprint arXiv:1408.5093},
      Title = {Caffe: Convolutional Architecture for Fast Feature Embedding},
      Year = {2014}
    }

    1.cuda-convnet

    采用的网络是Alex Krizhevsky的cuda-convnet,链接中详细描述了模型的定义、所用的参数、训练过程,在CIFAR-10上取得了很好的效果。

    2.数据集的准备

    本实验使用的数据集是CIFAR-10,一共有60000张32*32的彩色图像,其中50000张是训练集,另外10000张是测试集。数据集共有10个类别,分别如下所示
    图1 CIFAR-10数据集

    下面假定caffe的根目录是CAFFE_ROOT,在终端输入命令下载数据集:

    cd $CAFFE_ROOT 
    ./data/cifar10/get_cifar10.sh  #该脚本会下载二进制的cifar,并解压,会在/data/cifar10中出现很多batch文件
    ./examples/cifar10/create_cifar10.sh  #运行后将会在examples中出现数据集./cifar10_xxx_lmdb和数据集图像均值./mean.binaryproto

    3.模型

    CIFAR-10的卷积神经网络模型由卷积层,pooling层,ReLU,非线性变换层,局部对比归一化线性分类器组成。该模型定义在CAFFE_ROOT/examples/cifar10/cifar10_quick_train_test.prototxt中。

    4.训练和测试“quick”模型

    写好网络定义和solver以后,开始训练模型。输入下面的命令:

    cd $CAFFE_ROOT 
     ./examples/cifar10/train_quick.sh  #先以0.001的学习率迭代4000次,再以0.01的学习率接着再迭代1000次,共5000次

    可以看到每一层的详细信息、连接关系及输出的形式,方便调试。
    图2
    初始化后开始训练:
    图3
    在solver的设置中,每100次迭代会输出一次训练损失,测试是500次迭代输出一次:
    图4
    训练阶段,lr是学习率,loss是训练函数。测试阶段,score 0是准确率,score 1是损失函数。最后的结果:
    图5
    测试准确率大约有0.75,模型参数存储在二进制protobuf格式的文件cifar10_quick_iter_5000中。
    参考CAFFE_ROOT/examples/cifar10/cifar10_quick.prototxt的模型定义,就可以训练其他数据了。

    5.GPU使用

    CIFAR-10比较小,可以用GPU训练,当然也可以用CPU训练。为了比较CPU和GPU的训练速度,通过修改cifar*solver.prototxt中的一行代码来实现。

    # solver mode: CPU or GPU 
    solver_mode: CPU
    • 1
    • 2

    6.”full”模型

    同理可以训练full模型,full模型比quick模型迭代次数多,一共迭代70000次,前60000次学习率是0.001,中间5000次学习率是0.0001,最后5000次学习率是0.00001。full模型的网络层数也比quick模型多。
    命令是:

    cd $CAFFE_ROOT 
    ./examples/cifar10/train_full.sh
    • 1
    • 2

    测试准确率也比quick模型高,大约有0.82。
    这里写图片描述

    转自 http://blog.csdn.net/liumaolincycle/article/details/47258937

  • 相关阅读:
    神经网络损失函数公式解读
    centos 安装python PIL模块
    Centos6.8 安装dlib库时出错【升级gcc 到4.9.0以上】
    python DBUtils 线程池 连接 Postgresql(多线程公用线程池,DB-API : psycopg2)
    Postgresql 查看锁的过程
    Python yield 函数功能
    Centos6.8 安装spark-2.3.1 以及 scala-2.12.2
    Oracle---常用SQL语法和数据对象
    Oracle---number数据类型
    java框架篇---hibernate之连接池
  • 原文地址:https://www.cnblogs.com/is-Tina/p/7705391.html
Copyright © 2011-2022 走看看