zoukankan      html  css  js  c++  java
  • [面试题]实现im2col

    题目

    实际上在一些深度学习框架的底层,当实现Conv2D运算时,是将Conv转化为im2col和GEMM来进行运算的(比如Caffe和MxNet),之前面试的时候就被问到怎么实现im2col。

    img2col是将img和kernel对应的那一块铺开成一行,然后将kernel铺成一列,两者进行矩阵乘法运算,这样可以减少内存搬运。

    代码

    下面就实现了im2col的py版本:

    # import os
    import numpy as np
    
    
    def im2col(input_data, ksize, stride=1, pad=0):
        N, C, H, W = input_data.shape
        out_h = (H + 2 * pad - ksize) // stride + 1
        out_w = (W + 2 * pad - ksize) // stride + 1
    
        img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], "constant")
        col = np.zeros((N, C, ksize, ksize, out_h, out_w))
    
        for y in range(ksize):
            y_max = y + stride * out_h
            for x in range(ksize):
                x_max = x + stride * out_w
                col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]     
        col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
        return col
    
    
    def main():
        A = np.arange(1, 49).reshape(3, 4, 4)
        input_img = A.reshape(1, 3, 4, 4)
        col = im2col(input_img, 3, stride=1, pad=1)
        print(col)
    
    
    if __name__ == "__main__":
        main()
    
    
  • 相关阅读:
    pch文件的创建
    常用的Xcode插件下载地址
    内存管理
    学习笔记-static的作用
    IOS 之label的自适应
    OC中的循环引用
    理解事务的4种隔离级别
    Solrcloud集群搭建
    常见前端浏览器兼容问题及解决方案
    Java内存溢出详解及配置
  • 原文地址:https://www.cnblogs.com/wildkid1024/p/13772833.html
Copyright © 2011-2022 走看看