zoukankan      html  css  js  c++  java
  • 利用 PyTorch 复现 Spatial Pyramid Pooling（SPP）层

    SPPNet 的原理这里就不展开讲了，直接上代码。

     1 from math import floor, ceil
     2 import torch
     3 import torch.nn as nn
     4 import torch.nn.functional as F
     5 
     6 class SpatialPyramidPooling2d(nn.Module):
     7     r"""apply spatial pyramid pooling over a 4d input(a mini-batch of 2d inputs 
     8     with additional channel dimension) as described in the paper
     9     'Spatial Pyramid Pooling in deep convolutional Networks for visual recognition'
    10     Args:
    11         num_level:
    12         pool_type: max_pool, avg_pool, Default:max_pool
    13     By the way, the target output size is num_grid:
    14         num_grid = 0
    15         for i in range num_level:
    16             num_grid += (i + 1) * (i + 1)
    17         num_grid = num_grid * channels # channels is the channel dimension of input data
    18     examples:
    19         >>> input = torch.randn((1,3,32,32), dtype=torch.float32)
    20         >>> net = torch.nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,stride=1),
    21                                       nn.ReLU(),
    22                                       SpatialPyramidPooling2d(num_level=2,pool_type='avg_pool'),
    23                                       nn.Linear(32 * (1*1 + 2*2), 10))
    24         >>> output = net(input)
    25     """
    26     
    27     def __init__(self, num_level, pool_type='max_pool'):
    28         super(SpatialPyramidPooling2d, self).__init__()
    29         self.num_level = num_level
    30         self.pool_type = pool_type
    31 
    32     def forward(self, x):
    33         N, C, H, W = x.size()
    34         for i in range(self.num_level):
    35             level = i + 1
    36             kernel_size = (ceil(H / level), ceil(W / level))
    37             stride = (ceil(H / level), ceil(W / level))
    38             padding = (floor((kernel_size[0] * level - H + 1) / 2), floor((kernel_size[1] * level - W + 1) / 2))
    39 
    40             if self.pool_type == 'max_pool':
    41                 tensor = (F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)).view(N, -1)
    42             else:
    43                 tensor = (F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)).view(N, -1)
    44             
    45             if i == 0:
    46                 res = tensor
    47             else:
    48                 res = torch.cat((res, tensor), 1)
    49         return res
    50     def __repr__(self):
    51         return self.__class__.__name__ + '(' 
    52             + 'num_level = ' + str(self.num_level) 
    53             + ', pool_type = ' + str(self.pool_type) + ')'
    54     
    55 
    56 class SPPNet(nn.Module):
    57     def __init__(self, num_level=3, pool_type='max_pool'):
    58         super(SPPNet,self).__init__()
    59         self.num_level = num_level
    60         self.pool_type = pool_type
    61         self.feature = nn.Sequential(nn.Conv2d(3,64,3),
    62                                     nn.ReLU(),
    63                                     nn.MaxPool2d(2),
    64                                     nn.Conv2d(64,64,3),
    65                                     nn.ReLU())
    66         self.num_grid = self._cal_num_grids(num_level)
    67         self.spp_layer = SpatialPyramidPooling2d(num_level)
    68         self.linear = nn.Sequential(nn.Linear(self.num_grid * 64, 512),
    69                                     nn.Linear(512, 10))
    70     def _cal_num_grids(self, level):
    71         count = 0
    72         for i in range(level):
    73             count += (i + 1) * (i + 1)
    74         return count
    75 
    76     def forward(self, x):
    77         x = self.feature(x)
    78         x = self.spp_layer(x)
    79         print(x.size())
    80         x = self.linear(x)
    81         return x
    82 
    83 if __name__ == '__main__':
    84     a = torch.rand((1,3,64,64))
    85     net = SPPNet()
    86     output = net(a)
    87     print(output)
  • 相关阅读:
    nodejs实战的github地址,喜欢的你还等啥
    java初始化深度剖析
    第三篇之消息的收发
    第二篇之收发消息的封装
    微信公众号开发第一篇之基本开发环境的搭建
    微信开发调试工具
    微信公众号开发入门教程第一篇
    linux常见驱动修改
    微信硬件开发步骤
    Linux系统快速启动方案
  • 原文地址:https://www.cnblogs.com/qinduanyinghua/p/9016235.html
Copyright © 2011-2022 走看看