zoukankan      html  css  js  c++  java
  • PSPnet模型结构及实现代码

    
    

    该模块融合了4种不同金字塔尺度的特征,第一行红色是最粗糙的特征–全局池化生成单个bin输出,后面三行是不同尺度的池化特征。

    为了保证全局特征的权重,如果金字塔共有N个级别,则在每个级别后使用1×1 1×11×1的卷积将对于级别通道降为原本的1/N。再通过双线性插值获得未池化前的大小,最终concat到一起。












    1
    import torch 2 import torch.nn.functional as F 3 from torch import nn 4 from torchvision import models 5 6 from utils import initialize_weights 7 from utils.misc import Conv2dDeformable 8 from .config import res101_path 9 10 //金字塔模块,将从前面卷积结构提取的特征分别进行不同的池化操作,得到不同感受野以及全局语境信息(或者叫做不同层级的信息) 11 class _PyramidPoolingModule(nn.Module): 12 def __init__(self, in_dim, reduction_dim, setting): 13 super(_PyramidPoolingModule, self).__init__() 14 self.features = [] 15 for s in setting: //对应不同的池化操作,单个bin,多个bin 16 self.features.append(nn.Sequential( 17 nn.AdaptiveAvgPool2d(s), 18 nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 19 nn.BatchNorm2d(reduction_dim, momentum=.95), 20 nn.ReLU(inplace=True) 21 )) 22 self.features = nn.ModuleList(self.features) 23 24 def forward(self, x): 25 x_size = x.size() 26 out = [x] 27 for f in self.features: 28 out.append(F.upsample(f(x), x_size[2:], mode='bilinear')) 29 out = torch.cat(out, 1) 30 return out 31 32 //整个pspnet网络的结构 33 class PSPNet(nn.Module): 34 def __init__(self, num_classes, pretrained=True, use_aux=True): 35 super(PSPNet, self).__init__() 36 self.use_aux = use_aux 37 resnet = models.resnet101() //采用resnet101作为骨干模型,提取特征 38 if pretrained: 39 resnet.load_state_dict(torch.load(res101_path)) 40 self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 41 self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 42     //设置带洞卷积的参数(dilation),以及卷积的参数 43 for n, m in self.layer3.named_modules(): 44 if 'conv2' in n: 45 m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 46 elif 'downsample.0' in n: 47 m.stride = (1, 1) 48 for n, m in self.layer4.named_modules(): 49 if 'conv2' in n: 50 m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) 51 elif 'downsample.0' in n: 52 m.stride = (1, 1) 53     //加入ppm模块,以及最后的连接层(卷积) 54 self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6)) 55 self.final = nn.Sequential( 56 nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), 57 nn.BatchNorm2d(512, momentum=.95), 58 nn.ReLU(inplace=True), 59 nn.Dropout(0.1), 60 nn.Conv2d(512, num_classes, kernel_size=1) 61 ) 62 63 if use_aux: 64 self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1) 65 initialize_weights(self.aux_logits) 66 # 初始化权重 67 initialize_weights(self.ppm, self.final) 68 69 def forward(self, x): 70 x_size = x.size() 71 x = self.layer0(x) 72 x = self.layer1(x) 73 x = self.layer2(x) 74 x = self.layer3(x) 75 if self.training and self.use_aux: 76 aux = self.aux_logits(x) 77 x = self.layer4(x) 78 x = self.ppm(x) 79 x = self.final(x) 80 if self.training and self.use_aux: 81 return F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear') 82 return F.upsample(x, x_size[2:], mode='bilinear')
  • 相关阅读:
    再谈加密-RSA非对称加密的理解和使用
    WEB开发中的字符集和编码
    网页实时聊天之PHP实现websocket
    PHP中的回调函数和匿名函数
    shell实现SSH自动登陆
    初探PHP多进程
    PHP的openssl加密扩展使用小结
    搭建自己的PHP框架心得(三)
    docker 快速搭建Nexus3
    用图形数据库Neo4j 设计权限模块
  • 原文地址:https://www.cnblogs.com/ywheunji/p/10704237.html
Copyright © 2011-2022 走看看