zoukankan      html  css  js  c++  java
  • Caffe中im2col的实现解析

    这里,我将Caffe中im2col的实现代码直接提取了出来,并用C++程序把结果打印输出,方便理解。代码如下:

      1 #include<iostream>
      2 
      3 using namespace std;
      4 
      // Range predicate used for padding checks: reports whether `a` lies in the
      // half-open interval [0, b). Yields false for every `a` when `b` is
      // zero or negative.
      bool is_a_ge_zero_and_a_lt_b(int a,int b)
      {
          const bool below_range = a < 0;
          const bool above_range = a >= b;
          return !(below_range || above_range);
      }
     10 
     /**
      * Unrolls sliding-window patches of an image into a dense matrix (im2col).
      *
      * Input layout: `data_im` is read as `channels` consecutive planes of
      * `height * width` floats, each plane row-major.
      *
      * Output layout: `data_col` is written sequentially as a row-major matrix
      * with `channels * kernel_h * kernel_w` rows and `output_h * output_w`
      * columns, where
      *   output_h = (height + 2*pad_h - (dilation_h*(kernel_h-1) + 1)) / stride_h + 1
      *   output_w = (width  + 2*pad_w - (dilation_w*(kernel_w-1) + 1)) / stride_w + 1.
      * Row (c, kernel_row, kernel_col) holds, for every output position, the
      * input sample that kernel tap touches; taps that fall into the padding
      * region are written as 0.
      *
      * @param data_im     input image planes (not modified)
      * @param channels    number of input planes
      * @param height      rows per plane
      * @param width       columns per plane
      * @param kernel_h    kernel rows
      * @param kernel_w    kernel columns
      * @param pad_h       implicit zero padding above and below
      * @param pad_w       implicit zero padding left and right
      * @param stride_h    vertical step between output positions
      * @param stride_w    horizontal step between output positions
      * @param dilation_h  vertical spacing between kernel taps
      * @param dilation_w  horizontal spacing between kernel taps
      * @param data_col    output buffer; the caller must provide room for
      *                    channels * kernel_h * kernel_w * output_h * output_w floats
      *
      * Degenerate geometry (non-positive channel count or stride, or a dilated
      * kernel larger than the padded input so that an output dimension is not
      * positive) produces no output: the function returns without touching
      * `data_col`.
      */
     void im2col_cpu(const float* data_im, const int channels,
         const int height, const int width, const int kernel_h, const int kernel_w,
         const int pad_h, const int pad_w,
         const int stride_h, const int stride_w,
         const int dilation_h, const int dilation_w,
         float* data_col) {
       // A zero stride would divide by zero below; a negative channel count has
       // no meaningful output. Bail out before doing any arithmetic with them.
       if (channels <= 0 || stride_h <= 0 || stride_w <= 0) {
         return;
       }

       const int output_h = (height + 2 * pad_h -
         (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
       const int output_w = (width + 2 * pad_w -
         (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

       // When the dilated kernel does not fit into the padded image the formula
       // yields a non-positive size; there is nothing to emit in that case.
       if (output_h <= 0 || output_w <= 0) {
         return;
       }

       const int channel_size = height * width;

       for (int channel = 0; channel < channels; ++channel, data_im += channel_size) {
         for (int kernel_row = 0; kernel_row < kernel_h; ++kernel_row) {
           for (int kernel_col = 0; kernel_col < kernel_w; ++kernel_col) {
             // Input row touched by this kernel tap at the first output row.
             int input_row = -pad_h + kernel_row * dilation_h;

             for (int output_row = 0; output_row < output_h;
                  ++output_row, input_row += stride_h) {
               if (input_row < 0 || input_row >= height) {
                 // The whole output row reads from vertical padding.
                 for (int output_col = 0; output_col < output_w; ++output_col) {
                   *(data_col++) = 0;
                 }
                 continue;
               }

               const float* const image_row = data_im + input_row * width;
               // Input column touched by this kernel tap at the first output column.
               int input_col = -pad_w + kernel_col * dilation_w;

               for (int output_col = 0; output_col < output_w;
                    ++output_col, input_col += stride_w) {
                 // Horizontal padding contributes zeros as well.
                 const bool inside = input_col >= 0 && input_col < width;
                 *(data_col++) = inside ? image_row[input_col] : 0.0f;
               }
             }
           }
         }
       }
     }
     48 
     49 
     50 int main()
     51 {
     52      float* data_im;
     53     int height=5;
     54     int width=5;   
     55     int kernel_h=3;   
     56     int kernel_w=3;
     57     int pad_h=1;   
     58     int pad_w=1;
     59     int stride_h=1;   
     60     int stride_w=1;
     61     int dilation_h=1;   
     62     int dilation_w=1;
     63     float* data_col;
     64     int channels =3;
     65     const int output_h = (height + 2 * pad_h -
     66     (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
     67       const int output_w = (width + 2 * pad_w -
     68     (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
     69     data_im = new float[channels*height*width];
     70     data_col = new float[channels*output_h*output_w*kernel_h*kernel_w];
     71     
     72     //init input image data
     73     for(int m=0;m<channels;++m)
     74     {
     75       for(int i=0;i<height;++i)
     76       {
     77         for(int j=0;j<width;++j)
     78         {
     79           data_im[m*width*height+i*width+j] = m*width*height+ i*width +j;
     80           cout <<data_im[m*width*height+i*width+j] <<' ';
     81         }
     82         cout <<endl;
     83       }
     84     }
     85     
     86     im2col_cpu(data_im, channels,
     87      height,width, kernel_h, kernel_w,
     88     pad_h, pad_w,
     89     stride_h, stride_w,
     90     dilation_h, dilation_w,
     91      data_col);
     92     cout <<channels<<endl;
     93     cout <<output_h<<endl;
     94     cout <<output_w<<endl;
     95     cout <<kernel_h<<endl;
     96     cout <<kernel_w<<endl;
     97    // cout <<"error"<<endl;
     98     for(int i=0;i<kernel_w*kernel_h*channels;++i)
     99     {    
    100         for(int j=0;j<output_w*output_h;++j)
    101         {
    102             cout <<data_col[i*output_w*output_h+j]<<' ';
    103         }
    104         cout <<endl;
    105     }
    106 
    107     return 0;
    108 }

    多通道卷积的示意图网上已经有很多了,大家能搜到的基本都出自同一篇文章。这里附上我自己的理解过程,它和上面程序的输出是完全一致的。

  • 相关阅读:
    Design Pattern Explained
    StringBuilder or StringBuffer
    Algorithms
    Difference between pages and blocks
    Date Time Calendar
    Math if fun
    Sublime Text
    Java Regex
    Learning C
    跨域通信/跨域上传浅析
  • 原文地址:https://www.cnblogs.com/jourluohua/p/9735897.html
Copyright © 2011-2022 走看看