zoukankan      html  css  js  c++  java
  • 卷积操作详解

    今天研究了一下卷积计算。
    卷积涉及到的两个输入为: 图像和filter

    • 图像: 维度为 C*H*W C是channel, 也叫做 depth, H和W就是图像的宽和高了。

    • filter, 维度为 K*K, 假设 filter的个数为 M个
      直接进行卷积的伪代码为

    for w in 1..W (img_width)
      for h in 1..H (img_height)
        for x in 1..K (filter_width)
          for y in 1..K (filter_height)
            for m in 1..M (num_filters)
              for c in 1..C (img_channel)
                output(w, h, m) += input(w+x, h+y, c) * filter(m, x, y, c)
              end
            end
          end
        end
      end
    end

    -使用矩阵进行卷积操作,计算量:

    卷积运算,输入M层,输出N层,核尺寸k。输入数据大小H*W。
    卷积参数数量:weight + bias = M*N*k*k+N
    卷积运算量:H*W*N*M^2*K^4 ??

    这里写图片描述 

    • 卷积就变成了矩阵乘法 (Gemm in BLAS) . BLAS 库有(MKL, Atlas, CuBLAS)。这个算法最近被打败:Alex Krizhevsky’s 在 cuda-convnet [2]的优化 code [3]:

    https://github.com/soumith/convnet-benchmarks

    • 卷积之后的输出的维度为 num_filter* out_h * out_w;注意out_h 和 out_w 是img_h, img_w经过pading 和stride之后的宽高。

    hout=himg+2PadingKfilterhS+1

    wout=wimg+2PadingKfilterwS+1


    这里写图片描述
    这个后续再看
    这里写图片描述

    ///////////////////////////////////////////////////////////////////////////////
    // Simplest 2D convolution routine. It is easy to understand how convolution
    // works, but is very slow, because of no optimization.
    ///////////////////////////////////////////////////////////////////////////////
    bool convolve2DSlow(unsigned char* in, unsigned char* out, int dataSizeX, int dataSizeY,
                        float* kernel, int kernelSizeX, int kernelSizeY)
    {
        int i, j, m, n, mm, nn;
        int kCenterX, kCenterY;                         // center index of kernel
        float sum;                                      // temp accumulation buffer
        int rowIndex, colIndex;
    
        // check validity of params
        if(!in || !out || !kernel) return false;
        if(dataSizeX <= 0 || kernelSizeX <= 0) return false;
    
        // find center position of kernel (half of kernel size)
        kCenterX = kernelSizeX / 2;
        kCenterY = kernelSizeY / 2;
    
        for(i=0; i < dataSizeY; ++i)                // rows
        {
            for(j=0; j < dataSizeX; ++j)            // columns
            {
                sum = 0;                            // init to 0 before sum
                for(m=0; m < kernelSizeY; ++m)      // kernel rows
                {
                    mm = kernelSizeY - 1 - m;       // row index of flipped kernel
    
                    for(n=0; n < kernelSizeX; ++n)  // kernel columns
                    {
                        nn = kernelSizeX - 1 - n;   // column index of flipped kernel
    
                        // index of input signal, used for checking boundary
                        rowIndex = i + m - kCenterY;
                        colIndex = j + n - kCenterX;
    
                        // ignore input samples which are out of bound
                        if(rowIndex >= 0 && rowIndex < dataSizeY && colIndex >= 0 && colIndex < dataSizeX)
                            sum += in[dataSizeX * rowIndex + colIndex] * kernel[kernelSizeX * mm + nn];
                    }
                }
                out[dataSizeX * i + j] = (unsigned char)((float)fabs(sum) + 0.5f);
            }
        }
    
        return true;
    }
  • 相关阅读:
    POST和GET的区别
    Java设计模式6大原则
    JAVA23种工厂模式
    使用jsp实现用户登录请求
    MVC模式
    使用idea查询数据库内容
    mysql常见错误
    定义外键和建表原则
    CSS制作圆角边框
    2、JS的编写位置
  • 原文地址:https://www.cnblogs.com/ranjiewen/p/8638038.html
Copyright © 2011-2022 走看看