zoukankan      html  css  js  c++  java
  • 关于卷积网络以及反卷积网络shape的计算

    CNN的计算方式:

    w1 = (w - F_w + 2p) / s_w + 1

    h1 = (h - F_h + 2p) / s_h + 1

    其中 w, h 分别为上一层的宽高, Filters(kernel)的大小为 F_w, F_h

    strides 步长为: s_w, s_h

    p 为padding 的大小

    DeCNN 的计算方式:

    w1 = (w -1 )* s_w + F_w - 2p

    h1 = (h -1 )* s_h + F_h - 2p

     

    其中 w, h 分别为上一层的宽高, Filters(kernel)的大小为 F_w, F_h

    strides 步长为: s_w, s_h

    p 为padding 的大小

    上面的有点错误, 看了tensorflow的实现:

    具体代码如下:

    def _compute_output_shape(self, input_shape):
      input_shape = tensor_shape.TensorShape(input_shape).as_list()
      output_shape = list(input_shape)
      if self.data_format == 'channels_first':
        c_axis, h_axis, w_axis = 1, 2, 3
      else:
        c_axis, h_axis, w_axis = 3, 1, 2
    
      kernel_h, kernel_w = self.kernel_size
      stride_h, stride_w = self.strides
    
      output_shape[c_axis] = self.filters
      output_shape[h_axis] = utils.deconv_output_length(
          output_shape[h_axis], kernel_h, self.padding, stride_h)
      output_shape[w_axis] = utils.deconv_output_length(
          output_shape[w_axis], kernel_w, self.padding, stride_w)
      return tensor_shape.TensorShape(output_shape)
    

    这里就是说, W, H的计算方式,有额外的utils包来辅助完成,具体的代码如下:

    def deconv_output_length(input_length, filter_size, padding, stride):
      """Determines output length of a transposed convolution given input length.
    
      Arguments:
          input_length: integer.
          filter_size: integer.
          padding: one of "same", "valid", "full".
          stride: integer.
    
      Returns:
          The output length (integer).
      """
      if input_length is None:
        return None
      input_length *= stride
      if padding == 'valid':
        input_length += max(filter_size - stride, 0)
      elif padding == 'full':
        input_length -= (stride + filter_size - 2)
      return input_length

    也就是说,分了三种padding的情况, “same”、"valid"、"full"三种方式,而每一种方式都不同。代码上给了后两者的实现。

    这说明,如果padding使用的是“same”的情况的话。input_lenght = input_lenght * 2。

    所以,DeCNN的输出计算分为三种方式,做如下总结:

    “same”:

    input_length *= stride

    "valid":

    input_length = input_length * stride + max(filter_size - stride, 0)

    "full":

    input_length = input_length * stride - stride + filter_size - 2 = (input_lenght - 1) * stride + filter_size - 2

    这里Filter_sieze为卷积核的大小,及kernel_size
  • 相关阅读:
    C#关于MSMQ通过HTTP远程发送专有队列消息的问题
    ASP.NET中进行消息处理(MSMQ) 三
    ASP.NET中进行消息处理(MSMQ) 二
    ASP.NET中进行消息处理(MSMQ) 一
    日志插件 log4net 的使用
    在64位windows下使用instsrv.exe和srvany.exe创建windows服务
    Windows下MemCache多端口安装配置
    把页面上DIV元素生成图片
    memcached协议
    没钱买珍珠首饰,能够画一个
  • 原文地址:https://www.cnblogs.com/flyu6/p/8417083.html
Copyright © 2011-2022 走看看