zoukankan      html  css  js  c++  java
  • 使用matlab自带工具实现rcnn

    平台:matlab2016b

    matlab自带一个cifar10Net工具可用于深度学习。

    图片标注

    这里使用的是matlab自带的工具trainingImageLabeler对图像进行roi的标注。

    选择AddImages将要训练的图片放进去(可以放入多张图片),在ROI Label区域右键可以选择改变label 的color和name,如果要训练多个类,也可以点击Add ROI Label来添加label。

    所有图像标注完成后点击Export ROIs后会得到一个table(或stuct)变量,使用

    save(‘file’,‘variable’);
    

    命令来保存
    因为cifar10Net使用的是table,如果你的数据集使用的是stuct,
    这里使用

     data=struct2table(file);
    

    来将stuct转化为table

    imageFilename代表了图片所存储的位置;
    tire代表了图片中标注的轮胎,用矩阵存储,分别为roi左上的坐标(x,y)和roi的大小(width,height);

    RCNN训练

    我们来查看下网络结构

    load('rcnnStopSigns.mat','cifar10Net');
    cifar10Net.Layers
    

    会得到以下输出

    ans = 
    
    15x1 Layer array with layers:
    
     1   'imageinput'    Image Input             32x32x3 images with 'zerocenter' normalization
     2   'conv'          Convolution             32 5x5x3 convolutions with stride [1  1] and padding [2  2]
     3   'relu'          ReLU                    ReLU
     4   'maxpool'       Max Pooling             3x3 max pooling with stride [2  2] and padding [0  0]
     5   'conv_1'        Convolution             32 5x5x32 convolutions with stride [1  1] and padding [2  2]
     6   'relu_1'        ReLU                    ReLU
     7   'maxpool_1'     Max Pooling             3x3 max pooling with stride [2  2] and padding [0  0]
     8   'conv_2'        Convolution             64 5x5x32 convolutions with stride [1  1] and padding [2  2]
     9   'relu_2'        ReLU                    ReLU
    10   'maxpool_2'     Max Pooling             3x3 max pooling with stride [2  2] and padding [0  0]
    11   'fc'            Fully Connected         64 fully connected layer
    12   'relu_3'        ReLU                    ReLU
    13   'fc_1'          Fully Connected         10 fully connected layer
    14   'softmax'       Softmax                 softmax
    15   'classoutput'   Classification Output   cross-entropy with 'airplane', 'automobile', and 8 other classes
    

    通过观察可以看出,一共只有三个卷积层
    我们要对这个网络进行微调,因为我这里只训练了一个车轮,提供的数据中还包含有无标注的图片,所以全连接层的输出要改成2。后面再接上一个softmax层和一个classificationLayer,并且定义训练方式:

    x=cifar10Net.Layers(1:end-3);
    
    lastlayers = [
    fullyConnectedLayer(2,'Name','fc8','WeightLearnRateFactor',1, 'BiasLearnRateFactor',1)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classification')
    ];
    
    options = trainingOptions('sgdm', ...
     'MiniBatchSize', 32, ...
     'InitialLearnRate', 1e-6, ...
     'MaxEpochs', 100);
    

    RCNN的训练主要使用trainRCNNObjectDetector.m函数

     detector = trainRCNNObjectDetector(groundTruth,network,options)
    

    groundTruth - 具有2个或更多列的表。 第一列必须包含图像文件名。 图像可以是灰度或真彩色,可以是IMREAD支持的任何格式。 其余列必须包含指定每个图像内对象位置的[x,y,width,height]边框的M×4矩阵。 每列表示单个对象类,例如。 人,车,狗。 其实就是之前使用trainingImageLabeler做标注得到的数据。

    network - 即为CNN的网络结构

    options - 即为网络训练的参数。包括初始化学习率、迭代次数、BatchSize等等。

    除了以上三个参数外,还有

    ‘PositiveOverlapRange’ - 一个双元素向量,指定0和1之间的边界框重叠比例范围。与指定范围内(即之前做图片标注画出的框)的边界框重叠的区域提案被用作正训练样本。Default: [0.5 1]

    ‘NegativeOverlapRange’ - 一个双元素向量,指定0和1之间的边界框重叠比例范围。与指定范围内(即之前做图片标注画出的框)的边界框重叠的区域提案被用作负训练样本。Default: [0.1 0.5]

    在训练之前,RCNN会从训练图片中得到很多候选框,其中满足正样本要求的会被当做训练正样本,而满足负样本要求的会被当做训练负样本。

    ‘NumStrongestRegions’ - 用于生成训练样本的最强区域建议的最大数量(即最后得到的候选框数量)。 降低该值以加快处理时间,以训练准确性为代价。 将此设置为inf以使用所有区域提案。Default: 2000

    之后对训练完成的结果进行检测

    clear;
    tic;
    load myRCNN.mat;
    detectedImg = imread('cars_train_croped(227_227)8031.jpg');
    
    [bbox, score, label] = detect(myRCNN, detectedImg, 'MiniBatchSize', 20);
    
    imshow(detectedImg);
    
    idx=find(score>0.1);
    bbox = bbox(idx, :);
    n=size(idx,1);
    for i=1:n
        annotation = sprintf('%s: (Confidence = %f)', label(idx(i)), score(idx(i)));
        de = insertObjectAnnotation(detectedImg, 'rectangle', bbox(i,:), annotation);
    end
    
    figure
    imshow(de);
    toc;
    

    参考博客:https://blog.csdn.net/qq_33801763/article/details/77185457
    https://blog.csdn.net/mr_curry/article/details/53160914
    https://blog.csdn.net/u014096352/article/details/72854077

  • 相关阅读:
    错误: error C4996: 'strcpy': This function or variable may be unsafe. Consider using strcpy_s instead. 的处理方法
    C语言习题
    嵌入式芯片STM32F407
    c语言课后习题
    求方程式的根
    C语言课后习题
    LINUX常用指令
    在 pythonanywhere 上搭建 django 程序(Virtualenv+python2.7+django1.8)
    Git远程操作详解
    ./configure,make,make install的作用
  • 原文地址:https://www.cnblogs.com/passbyone/p/8883682.html
Copyright © 2011-2022 走看看