  • Caffe Study Series (15): Adding a New Layer

    How do you add a new layer to Caffe? There are four main steps:

    (1) Add the corresponding parameter message for the layer in ./src/caffe/proto/caffe.proto (a sketch follows this list);

    (2) Declare the layer's class in ./include/caffe/***layers.hpp, where *** stands for one of common_layers.hpp, data_layers.hpp, neuron_layers.hpp, vision_layers.hpp, loss_layers.hpp, and so on;

    (3) Create new .cpp and .cu (GPU) files under ./src/caffe/layers/ and implement the class there.

    (4) Add test code for the layer under ./src/caffe/gtest/, testing the layer's forward and backward passes as well as its speed. (This step can be skipped, but is recommended.)
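
    For step (1), the parameter message for the example layer used below might look roughly like this. It is only a sketch: the layer is assumed to be named PrecisionRecallLoss, point_num is the field the implementation later reads through precision_recall_loss_param(), and the tag number 148 and the default value are placeholders that must not collide with fields already defined in your caffe.proto.

    // Sketch for ./src/caffe/proto/caffe.proto (tag numbers are placeholders)
    message LayerParameter {
      // ... existing fields ...
      optional PrecisionRecallLossParameter precision_recall_loss_param = 148;
    }

    message PrecisionRecallLossParameter {
      // Number of threshold points used to sample the precision-recall curve.
      optional uint32 point_num = 1 [default = 100];
    }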

      

      This blogger added a network layer that computes gradients; the post is concise and clear:

      http://blog.csdn.net/shuzfan/article/details/51322976

     

      These bloggers added custom loss layers, which can serve as references:

      http://blog.csdn.net/langb2014/article/details/50489305

      http://blog.csdn.net/tangwei2014/article/details/46815231

     I studied the code by adding a precision_recall_loss layer; the main part is the implementation of precision_recall_loss_layer.cpp.

    #include <algorithm>  
    #include <cfloat>  
    #include <cmath>  
    #include <vector>  
    #include <opencv2/opencv.hpp>  
      
    #include "caffe/layer.hpp"  
    #include "caffe/util/io.hpp"  
    #include "caffe/util/math_functions.hpp"  
    #include "caffe/vision_layers.hpp"  
      
    namespace caffe {  
      
    // Initialization: call the parent class to do the corresponding setup
    template <typename Dtype>  
    void PrecisionRecallLossLayer<Dtype>::LayerSetUp(  
      const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {  
      LossLayer<Dtype>::LayerSetUp(bottom, top);  
    }  
    // Reshape the blobs
    template <typename Dtype>  
    void PrecisionRecallLossLayer<Dtype>::Reshape(  
      const vector<Blob<Dtype>*> &bottom,  
      const vector<Blob<Dtype>*> &top) {  
      // Likewise call the parent's Reshape first, then shape the loss_ member to match the input
      LossLayer<Dtype>::Reshape(bottom, top);  
      loss_.Reshape(bottom[0]->num(), bottom[0]->channels(),  
                    bottom[0]->height(), bottom[0]->width());  
      
      // Check the shapes of data and label: the two inputs must have identical dimensions
      CHECK_EQ(bottom[0]->num(), bottom[1]->num())  
          << "The number of num of data and label should be same.";  
      CHECK_EQ(bottom[0]->channels(), bottom[1]->channels())  
          << "The number of channels of data and label should be same.";  
      CHECK_EQ(bottom[0]->height(), bottom[1]->height())  
          << "The heights of data and label should be same.";  
      CHECK_EQ(bottom[0]->width(), bottom[1]->width())  
          << "The width of data and label should be same.";  
    }  
    // Forward pass: sweep thresholds and find the precision-recall breakeven point
    template <typename Dtype>
    void PrecisionRecallLossLayer<Dtype>::Forward_cpu(
      const vector<Blob<Dtype>*> &bottom, const vector<Blob<Dtype>*> &top) {
      const Dtype *data = bottom[0]->cpu_data();
      const Dtype *label = bottom[1]->cpu_data();
      // num is the batch size; count() is the total number of elements,
      // so dim is the number of elements per image (channels * height * width)
      const int num = bottom[0]->num();
      const int dim = bottom[0]->count() / num;
      const int channels = bottom[0]->channels();
      const int spatial_dim = bottom[0]->height() * bottom[0]->width();
      // Number of threshold points used to sample the precision-recall curve
      const int pnum =
        this->layer_param_.precision_recall_loss_param().point_num();
      top[0]->mutable_cpu_data()[0] = 0;
      // For each channel
      for (int c = 0; c < channels; ++c) {
        Dtype breakeven = 0.0;
        Dtype prec_diff = 1.0;
        for (int p = 0; p <= pnum; ++p) {
          int true_positive = 0;
          int false_positive = 0;
          int false_negative = 0;
          int true_negative = 0;
          for (int i = 0; i < num; ++i) {
            // Threshold for the p-th point on the curve
            const Dtype thresh = 1.0 / pnum * p;
            for (int j = 0; j < spatial_dim; ++j) {
              // Fetch the prediction and the corresponding label
              const Dtype data_value = data[i * dim + c * spatial_dim + j];
              const int label_value = (int)label[i * dim + c * spatial_dim + j];
              // Accumulate the confusion-matrix counts
              if (label_value == 1 && data_value >= thresh) {
                ++true_positive;
              }
              if (label_value == 0 && data_value >= thresh) {
                ++false_positive;
              }
              if (label_value == 1 && data_value < thresh) {
                ++false_negative;
              }
              if (label_value == 0 && data_value < thresh) {
                ++true_negative;
              }
            }
          }
          // Compute precision and recall at this threshold
          Dtype precision = 0.0;
          Dtype recall = 0.0;
          if (true_positive + false_positive > 0) {
            precision =
              (Dtype)true_positive / (Dtype)(true_positive + false_positive);
          } else if (true_positive == 0) {
            // No positive predictions at all: define precision as 1
            precision = 1.0;
          }
          if (true_positive + false_negative > 0) {
            recall =
              (Dtype)true_positive / (Dtype)(true_positive + false_negative);
          } else if (true_positive == 0) {
            // No positive labels at all: define recall as 1
            recall = 1.0;
          }
          // Keep the point where precision and recall are closest (the breakeven point)
          if (prec_diff > fabs(precision - recall)
              && precision > 0 && precision < 1
              && recall > 0 && recall < 1) {
            breakeven = precision;
            prec_diff = fabs(precision - recall);
          }
        }
        // This channel's loss contribution: 1 - breakeven point
        top[0]->mutable_cpu_data()[0] += 1.0 - breakeven;
      }
      // Average the loss over channels
      top[0]->mutable_cpu_data()[0] /= channels;
    }

    // Backward pass: this loss is not differentiable, so no gradient is implemented
    template <typename Dtype>
    void PrecisionRecallLossLayer<Dtype>::Backward_cpu(
      const vector<Blob<Dtype>*> &top, const vector<bool> &propagate_down,
      const vector<Blob<Dtype>*> &bottom) {
      for (int i = 0; i < propagate_down.size(); ++i) {
        if (propagate_down[i]) {
          NOT_IMPLEMENTED;
        }
      }
    }

    #ifdef CPU_ONLY
    STUB_GPU(PrecisionRecallLossLayer);
    #endif

    // Instantiate and register the layer
    INSTANTIATE_CLASS(PrecisionRecallLossLayer);
    REGISTER_LAYER_CLASS(PrecisionRecallLoss);

    }  // namespace caffe
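
     Since the layer is registered via REGISTER_LAYER_CLASS(PrecisionRecallLoss), it can be referenced from a net definition by the type name "PrecisionRecallLoss". The prototxt snippet below is only a hypothetical usage sketch: the blob names "prediction" and "label" and the point_num value are placeholders. The top blob receives the average over channels of (1 - breakeven point).

    layer {
      name: "precision_recall_loss"
      type: "PrecisionRecallLoss"
      bottom: "prediction"   # per-pixel scores, expected in [0, 1]
      bottom: "label"        # binary labels with the same shape as the scores
      top: "loss"
      precision_recall_loss_param {
        point_num: 100
      }
    }

     Note that Backward_cpu only calls NOT_IMPLEMENTED, so the layer provides no gradient; it is suitable as an evaluation metric (for example in the TEST phase), not as a training loss.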
  • Original post: https://www.cnblogs.com/573177885qq/p/6065625.html