zoukankan      html  css  js  c++  java
  • (转)如何使用caffe的MATLAB接口

    编译MatCaffe

    转自: http://blog.csdn.net/ws_20100/article/details/50525879

    使用如下命令编译MatCaffe

    make all matcaffe
    • 1

    之后,你可以用以下命令测试MatCaffe:

    make mattest
    • 1

    如果你在运行上面命令时,遇到如下错误:libstdc++.so.6 version ‘GLIBCXX_3.4.15’ not found,说明你的Matlab库不匹配。你需要在启动Matlab之前运行如下命令:

    export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:/usr/local/cuda/lib64
    export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6
    • 1
    • 2

    在Caffe根目录启动Matlab之后需要增加路径:

    addpath ./matlab
    • 1

    你可以使用savepath来保存Matlab搜索路径,这样下次就不用再添加路径了。


    使用MatCaffe

    MatCaffe 和 PyCaffe 的使用方法很相似。

    下面将用一个例子来解释MatCaffe的具体使用细节,假设你已经下载了BVLC CaffeNet,并且在caffe根目录启动matlab。

    model = './models/bvlc_reference_caffenet/deploy.prototxt';
    weights = './models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel';
    
    • 1
    • 2
    • 3

    1.设置模式和设备

    模式和设备的设置必须在创建一个net或solver之前。

    使用CPU:

    caffe.set_mode_cpu();
    • 1

    使用GPU并指定gpu_id:

    caffe.set_mode_gpu();
    caffe.set_device(gpu_id);
    • 1
    • 2

    2.创建一个网络并访问它的layers和blobs

    1.创建网络

    创建一个网络:

    net = caffe.Net(model, weights, 'test'); % create net and load weights
    • 1

    或者

    net = caffe.Net(model, 'test'); % create net but not load weights
    net.copy_from(weights); % load weights
    • 1
    • 2

    它可以创建一个如下的net对象:

      Net with properties:
               layer_vec: [1x23 caffe.Layer]
                blob_vec: [1x15 caffe.Blob]
                  inputs: {'data'}
                 outputs: {'prob'}
        name2layer_index: [23x1 containers.Map]
         name2blob_index: [15x1 containers.Map]
             layer_names: {23x1 cell}
              blob_names: {15x1 cell}
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    两个containers.Map对象可以通过layer或者blob的名称找到对应的索引。

    2.访问blob

    你可以访问网络中的每一个blob,将data的blob填充为全一:

    net.blobs('data').set_data(ones(net.blobs('data').shape));
    • 1

    data的blob中数值全部乘以10:

    net.blobs('data').set_data(net.blobs('data').get_data() * 10);
    • 1

    注意:因为Matlab是以1作为起始单元,且以列为主,在Matlab中使用四维blob为[width, height, channels, num],且width是最快的维度,而且要在BGR通道。而且Caffe使用单精度浮点型数据。如果你的数据不是浮点型的,set_data将会自动转换为single。

    3.访问layer

    你也可以访问网络的每一层,以便你作一些网络调整。例如把conv1参数乘以10:

    net.params('conv1', 1).set_data(net.params('conv1', 1).get_data() * 10); % set weights
    net.params('conv1', 2).set_data(net.params('conv1', 2).get_data() * 10); % set bias
    • 1
    • 2

    你也可以如下代码:

    net.layers('conv1').params(1).set_data(net.layers('conv1').params(1).get_data() * 10);
    net.layers('conv1').params(2).set_data(net.layers('conv1').params(2).get_data() * 10);
    • 1
    • 2

    4.保存网络

    你仅仅需要如下代码保存网络:

    net.save('my_net.caffemodel');
    • 1

    5.获得一层的类型(string)

    layer_type = net.layers('conv1').type;
    • 1

    3.前向和后向计算

    前向和后向计算可以使用net.forward或者net.forward_prefilled实现。函数net.forward将一个包含输入blob(s)的cell数组作为输入,并输出一个包含输出blob(s)的cell数组。函数net.forward_prefilled将使用输入blob(s)中的已有数据进行计算,没有输入数据,没有输出数据。

    在通过一些方法(如:data = rand(net.blobs('data').shape);)产生输入数据后,你可以运行:

    res = net.forward({data});
    prob = res{1};
    • 1
    • 2

    或者

    net.blobs('data').set_data(data);
    net.forward_prefilled();
    prob = net.blobs('prob').get_data();
    • 1
    • 2
    • 3

    后向计算使用net.backward或者net.backward_prefilled,并且把get_dataset_data替换为get_diffset_diff。在通过一些方法(例如prob_diff = rand(net.blobs('prob').shape);)产生输出blobs的梯度后,你可以运行:

    res = net.backward({prob_diff});
    data_diff = res{1};
    • 1
    • 2

    或者

    net.blobs('prob').set_diff(prob_diff);
    net.backward_prefilled();
    data_diff = net.blobs('data').get_diff();
    • 1
    • 2
    • 3

    然而,如上的后向计算并不能得到正确的结果,因为Caffe默认网络不需要后向计算。为了获取正确的后向计算结果,你需要在你的网络prototxt文件中设置force_backward: true

    在完成前向和后向计算之后,你可以获得中间blobs的data和diff。例如,你可以在前向计算后获取pool5的特征。

    4.Reshape

    假设你想要运行1幅图像,而不是10幅时:

    net.blobs('data').reshape([227 227 3 1]); % reshape blob 'data'
    net.reshape();
    • 1
    • 2

    然后,整个网络就reshape了,此时net.blobs('prob').shape应该是[1000 1];

    5.训练网络

    假设你按照ImageNET Tutorial的方法创建了训练lmdb和验证lmdb,产生一个solver并且在ILSVRC 2012 分类数据集上训练:

    solver = caffe.Solver('./models/bvlc_reference_caffenet/solver.prototxt');
    • 1

    这样可以创建一个solver对象:

      Solver with properties:
    
              net: [1x1 caffe.Net]
        test_nets: [1x1 caffe.Net]
    • 1
    • 2
    • 3
    • 4

    训练代码:

    solver.solve();
    • 1

    如果只想训练迭代1000次:

    solver.step(1000);
    • 1

    来获取迭代数量:

    iter = solver.iter();
    • 1

    来获取这个网络:

    train_net = solver.net;
    test_net = solver.test_nets(1);
    • 1
    • 2

    假设从一个snapshot中恢复网络训练:

    solver.restore('your_snapshot.solverstate');
    • 1

    6.输入和输出

    caffe.io类提供了基本的输入函数load_imageread_mean。例如,读取ILSVRC 2012 mean文件(假设你已经通过运行./data/ilsvrc12/get_ilsvrc_aux.sh下载imagenet例程辅助文件)

    mean_data = caffe.io.read_mean('./data/ilsvrc12/imagenet_mean.binaryproto');
    • 1

    为了读取Caffe例程图片,并且resize到[width, height],且假设width = 256; height = 256;

    im_data = caffe.io.load_image('./examples/images/cat.jpg');
    im_data = imresize(im_data, [width, height]); % resize using Matlab's imresize
    • 1
    • 2

    注意:width是最快的维度,通道为BGR,与Matlab存取图片的一般方式不一样。如果你不想要使用caffe.io.load_image,且想自己导入一幅图片:

    im_data = imread('./examples/images/cat.jpg'); % read image
    im_data = im_data(:, :, [3, 2, 1]); % convert from RGB to BGR
    im_data = permute(im_data, [2, 1, 3]); % permute width and height
    im_data = single(im_data); % convert to single precision
    • 1
    • 2
    • 3
    • 4

    你也可以看一下caffe/matlab/demo/classification_demo.m文件,了解如何将输入图片crop成多个输入块。

    你可以查看caffe/matlab/hdf5creation,了解如何通过Matlab读和写HDF5数据。但不提供额外的数据输出函数,因为在Matlab本身已经具有了强大的功能。

    7.清除nets和solvers

    调用caffe.reset_all()来清理你所创建的所有的solvers,和stand-alone nets。

  • 相关阅读:
    [Java] Hibernate
    python基础(十三):函数(一)公共操作
    python基础(十二):数据结构(五)集合
    python基础(十二):数据结构(四)字典
    python基础(十一):数据结构(三)元组
    python基础(十):数据结构(二)列表
    python基础(八):流程控制(五)循环
    python基础(七):流程控制(一)if
    python基础(六):运算符
    python基础(五):转换数据类型
  • 原文地址:https://www.cnblogs.com/byteHuang/p/7492633.html
Copyright © 2011-2022 走看看