zoukankan      html  css  js  c++  java
  • caffe卷积层代码阅读笔记

    卷积的实现思想:

    • 通过im2col将image转为一个matrix,将卷积操作转为矩阵乘法运算
    • 通过调用GEMM完毕运算操作
    • 以下两个图是我在知乎中发现的,“盗”用一下,确实非常好。能帮助理解。


      这里写图片描写叙述
      这里写图片描写叙述

    參数剖析

    • 配置參数:(从配置文件得来)
      kernel_h_ pad_h_ hole_h_ stride_h_
      kernel_w_ pad_w_ hole_w_ stride_w_
      is_1x1_:上面8个參数都为1时,该參数为true

    • 和输入有关的參数:(从bottom得来)
      num_
      channels_
      height_
      width_

    • 和卷积核有关的參数:(前两个參数从配置文件得来)
      num_output_
      group_
      this->blobs_[0].reset(new Blob(num_output_, channels_ / group_, kernel_h_, kernel_w_));
      this->blobs_[1].reset(new Blob(1, 1, 1, num_output_));
      this->param_propagate_down_

    • 和输出有关的參数:(计算得来)
      const int kernel_h_eff = kernel_h_ + (kernel_h_ - 1) * (hole_h_ - 1);
      const int kernel_w_eff = kernel_w_ + (kernel_w_ - 1) * (hole_w_ - 1);
      height_out_ = (height_ + 2 * pad_h_ - kernel_h_eff) / stride_h_ + 1;
      width_out_ = (width_ + 2 * pad_w_ - kernel_w_eff) / stride_w_ + 1;

    • 和矩阵运算有关的參数:(计算得来)
      M_ = num_output_ / group_;
      K_ = channels_ * kernel_h_ * kernel_w_ / group_;
      N_ = height_out_ * width_out_;
      col_buffer_.Reshape(1, channels_*kernel_h_*kernel_w_, height_out_, width_out_);// is_1x1_为false的时候用
      bias_multiplier_.Reshape(1, 1, 1, N_); //所有为1

    输入大小:(num_, channels_, height_, width_)
    输出大小:(num_, num_output_, height_out_, width_out_)

    重点函数剖析

    • 函数一:
      im2col_cpu(bottom_data + bottom[i]->offset(n),
      1, channels_, height_, width_,
      kernel_h_, kernel_w_, pad_h_, pad_w_,
      stride_h_, stride_w_, hole_h_, hole_w_,
      col_buff);

      该函数的目的是:依据配置參数,将一幅(1, channels_, height_, width_)的输入feature map expand成 (1, channels_*kernel_h_*kernel_w_, height_out_, width_out_)大小的矩阵。

      详细的实现方法是:
      内部主要有两套索引
      一套是在输入图像上的索引,各自是:c_im(channels), h_im(height), w_im(width)
      还有一套是在输出的col_buff上的。各自是:c(channels_col), h(height_col), w(width_col)

      循环变量来自输出的col_buff的维数,依据输出的位置计算相应在输入图像上的位置(col2imh函数和im2col函数是一个道理。两套坐标反着来即可)。把索引的代码整合出来。对着源码看。非常easy懂:

        const int kernel_h_eff = kernel_h + (kernel_h - 1) * (hole_h - 1);
        const int kernel_w_eff = kernel_w + (kernel_w - 1) * (hole_w - 1);
        int height_col = (height + 2 * pad_h - kernel_h_eff) / stride_h + 1;
        int width_col = (width + 2 * pad_w - kernel_w_eff) / stride_w + 1;
        int channels_col = channels * kernel_h * kernel_w;
        int w_offset = (c % kernel_w)  * hole_w;
        int h_offset = ((c / kernel_w) % kernel_h) * hole_h;
        int c_im = c / kernel_w / kernel_h;
        const int h_im = h * stride_h + h_offset - pad_h;
        const int w_im = w * stride_w + w_offset - pad_w;
    • 函数二:

      caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, K_,
      (Dtype)1., weight + weight_offset * g, col_buff + col_offset * g,
      (Dtype)0., top_data + top[i]->offset(n) + top_offset * g);

      该函数的目的是:
      将(num_output_/group_, channels_ /group_, kernel_h_, kernel_w_)卷积核看成一个(num_output_/group_, channels_*kernel_h_*kernel_w_/group_)的矩阵A,即大小为M_x K_。

      将(1, channels_*kernel_h_*kernel_w_, height_out_, width_out_)的col_buff看成group_个(channels_*kernel_h_*kernel_w_/group_, height_out_*width_out_)的矩阵B。即大小为K_x N_。

      两者相乘再加上偏置项。就能得到卷积的结果。

      解释caffe_cpu_gemm函数:
      事实上其内部包了一个cblas_sgemm函数。


      void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
      const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
      const int K, const float alpha, const float *A,
      const int lda, const float *B, const int ldb,
      const float beta, float *C, const int ldc)

      得到的结果是:
      C = alpha*op( A )*op( B ) + beta*C

      const enum CBLAS_ORDER Order,这是指的数据的存储形式,在CBLAS的函数中不管一维还是二维数据都是用一维数组存储,这就要涉及是行主序还是列主序。在C语言中数组是用 行主序。fortran中是列主序。

      假设是习惯于是用行主序,所以这个參数是用CblasRowMajor。假设是列主序的话就是 CblasColMajor。


      const int M,矩阵A的行,矩阵C的行
      const int N,矩阵B的列。矩阵C的列
      const int K,矩阵A的列。矩阵B的行

  • 相关阅读:
    spring boot的application配置文件
    C# WinForm 中英文实现, 国际化实现的简单方法
    VS2012 2013 显示查找功能 无法具体定位 解决方法
    C#使用HttpWebRequest 进行请求,提示 基础连接已经关闭: 发送时发生错误。
    VS 默认开发环境如何更改
    C# winfrom HttpWebRequest 请求获取html网页信息和提交信息
    C# 定时器 Timers.Timer Forms.Timer
    HTTP 错误 500.21
    配置iis时,浏览项目提示 无法识别的属性“targetFramework”。请注意属性名称区分大小写。
    asp xml对象转换为string
  • 原文地址:https://www.cnblogs.com/mfmdaoyou/p/7257335.html
Copyright © 2011-2022 走看看