zoukankan      html  css  js  c++  java
  • RetinaNet PyTorch implementation from scratch, part 02: FPN



    import torch.nn as nn
    class FPN(nn.Module):
        """Feature Pyramid Network neck as used by RetinaNet.

        Builds pyramid levels P3-P7 from the backbone feature maps C3, C4, C5:

        * P5, P4, P3 come from a top-down pathway. Each level's lateral 1x1 conv
          output is summed with the nearest-neighbour-upsampled coarser level,
          then smoothed by a 3x3 conv.
        * P6 is a 3x3 stride-2 conv on C5 (default) or on the P5 output when
          ``p6_from_p5`` is True.
        * P7 is ReLU followed by a 3x3 stride-2 conv on P6.

        Args:
            C3_size: number of channels of the C3 input.
            C4_size: number of channels of the C4 input.
            C5_size: number of channels of the C5 input.
            feature_size: number of channels of every output feature map.
            p6_from_p5: if True, P6 is computed from the P5 output (which has
                ``feature_size`` channels) instead of from C5.
        """

        def __init__(self, C3_size, C4_size, C5_size, feature_size=256, *, p6_from_p5=False):
            super(FPN, self).__init__()
            self.p6_from_p5 = p6_from_p5

            # Lateral 1x1 convs project each backbone level to feature_size channels;
            # the 3x3 convs smooth the merged maps.
            self.P5_1 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
            self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
            self.P4_1 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
            self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
            self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
            self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

            # NOTE: kept only so existing code that accesses these attributes still
            # works (they hold no parameters, so the state_dict is unaffected).
            # forward() upsamples to the exact lateral-map size instead, which also
            # handles odd spatial sizes where a fixed 2x factor would mismatch.
            self.P5_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
            self.P4_upsampled = nn.Upsample(scale_factor=2, mode='nearest')

            # "P6 is obtained via a 3x3 stride-2 conv on C5" (RetinaNet); optionally
            # on P5 instead, in which case its input has feature_size channels.
            p6_in_channels = feature_size if p6_from_p5 else C5_size
            self.P6 = nn.Conv2d(p6_in_channels, feature_size, kernel_size=3, stride=2, padding=1)

            # P7 = 3x3 stride-2 conv (downsampling) applied after a ReLU on P6.
            self.P7_1 = nn.ReLU()
            self.P7_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1)

        def forward(self, x):
            """Compute the pyramid.

            Args:
                x: a sequence of three tensors ``(C3, C4, C5)``.

            Returns:
                A list ``[P3, P4, P5, P6, P7]`` of tensors, each with
                ``feature_size`` channels.
            """
            C3, C4, C5 = x

            # Top-down pathway: merge each lateral map with the coarser merged map,
            # upsampled to the lateral map's exact spatial size.
            P5_x = self.P5_1(C5)
            P4_x = self.P4_1(C4)
            P4_x = P4_x + nn.functional.interpolate(P5_x, size=P4_x.shape[-2:], mode='nearest')
            P3_x = self.P3_1(C3)
            P3_x = P3_x + nn.functional.interpolate(P4_x, size=P3_x.shape[-2:], mode='nearest')

            # Smooth the merged maps (after they have been used for upsampling).
            P5_x = self.P5_2(P5_x)
            P4_x = self.P4_2(P4_x)
            P3_x = self.P3_2(P3_x)

            P6_x = self.P6(P5_x if self.p6_from_p5 else C5)
            P7_x = self.P7_2(self.P7_1(P6_x))
            return [P3_x, P4_x, P5_x, P6_x, P7_x]
  • 相关阅读:
    Activiti系列——如何在eclipse中安装 Activiti Designer插件
    C语言 二维数组与指针笔记
    Ubuntu linux设置从当前目录下加载动态库so文件
    Ubuntu14.04 搭建FTP服务器
    swiper 窗口宽度变化,页面宽度高度变化 导致自动滑动 解决方案
  • 原文地址:https://www.cnblogs.com/Valeyw/p/15019676.html
Copyright © 2011-2022 走看看