zoukankan      html  css  js  c++  java
  • pytorch-googleNet

    googleNet网络结构 

     输入网络: 由4个分支网络构成

    第一分支: 由1x1的卷积构成

    第二分支: 由1x1的卷积,3x3的卷积构成

    第三分支: 由1x1的卷积, 5x5的卷积构成

    第四分支: 由3x3的最大值池化, 1x1的卷积构成 

    import torch
    from torch import nn
    from torch.nn import functional as F
    
    class BasicConv2d(nn.Module):
        def __init__(self, in_channels, out_channels, **kwargs):
            super(BasicConv2d, self).__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) # 构造卷积层
            self.bn = nn.BatchNorm2d(out_channels, eps=0.001) # 构造标准化
    
        def forward(self, x):
            x = self.conv(x) # 进行卷积操作
            x = self.bn(x) # 进行标准化操作
            x = F.relu(x) # 进行激活层操作
    
            return x
    
    class Inception(nn.Module):
        def __init__(self, in_channels, pool_features):
            super(Inception, self).__init__()
            self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) # 1x1的卷积操作
    
            self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) # 进行卷积操作
            self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 
    
            self.branch3x3db1_1 = BasicConv2d(in_channels, 64, kernel_size=1)
            self.branch3x3db1_2 = BasicConv2d(64, 96, kernel_size=3)
            self.branch3x3db1_3 = BasicConv2d(96, 96, kernel_size=3)
    
            self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)
    
        def forward(self, x):
            branch1x1 = self.branch1x1(x)
    
            branch5x5 = self.branch5x5_1(x)
            branch5x5 = self.branch5x5_2(branch5x5)
    
            branch3x3db1_1 = self.branch3x3db1_1(x)
            branch3x3db1_2 = self.branch3x3db1_2(branch3x3db1_1)
            branch3x3db1_3 = self.branch3x3db1_3(branch3x3db1_2)
    
            branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
            branch_pool = self.branch_pool(branch_pool)
            # 进行卷积的叠加操作
            outputs = [branch1x1, branch5x5, branch3x3db1_3, branch_pool]
            outputs = torch.cat(outputs, dim=1)
    
            return outputs
  • 相关阅读:
    Redis与Memcached的incr/decr差异对比
    Linux sudo用法与配置
    Docker 常用命令
    Linux之间配置SSH互信(SSH免密码登录)
    SVN服务器搭建
    shell中参数的传递
    【代码更新】IIC协议建模——读写EEPROM
    串口完整项目之串口收发字符串
    串口发送模块——1字节数据发送
    状态机设计——从简单的按键消抖开始
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/11732635.html
Copyright © 2011-2022 走看看