zoukankan      html  css  js  c++  java
  • pytorch实现squeezenet

    squeezenet是16年发布的一款轻量级网络模型,模型很小,只有4.8M,可用于移动设备,嵌入式设备。

    关于squeezenet的原理可自行阅读论文或查找博客,这里主要解读下pytorch对squeezenet的官方实现。

    地址:https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py

    首先定义fire模块,这是squeezenet的核心所在,降低3X3卷积的数量。

    class Fire(nn.Module):
    
        def __init__(self, inplanes, squeeze_planes,
                     expand1x1_planes, expand3x3_planes):
            super(Fire, self).__init__()
            self.inplanes = inplanes
            self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)#定义压缩层,1X1卷积
            self.squeeze_activation = nn.ReLU(inplace=True)
            self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,#定义扩展层,1X1卷积
                                       kernel_size=1)
            self.expand1x1_activation = nn.ReLU(inplace=True)
            self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,#定义扩展层,3X3卷积
                                       kernel_size=3, padding=1)
            self.expand3x3_activation = nn.ReLU(inplace=True)
    
        def forward(self, x):
            x = self.squeeze_activation(self.squeeze(x))
            return torch.cat([
                self.expand1x1_activation(self.expand1x1(x)),
                self.expand3x3_activation(self.expand3x3(x))
            ], 1)

    可以看到首先定义压缩层与两个扩展层,压缩层用的是1X1卷积,扩展层是1X1卷积和3X3卷积的混合使用,网络inference的脉络是先经过压缩层,然后并行经过两个扩展层,最后将扩展层串联。

    定义完核心模块,来看网络整体。

    class SqueezeNet(nn.Module):
    
        def __init__(self, version=1.0, num_classes=1000):
            super(SqueezeNet, self).__init__()
            if version not in [1.0, 1.1]:
                raise ValueError("Unsupported SqueezeNet version {version}:"
                                 "1.0 or 1.1 expected".format(version=version))
            self.num_classes = num_classes
            if version == 1.0:
                self.features = nn.Sequential(
                    nn.Conv2d(3, 96, kernel_size=7, stride=2),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(96, 16, 64, 64),
                    Fire(128, 16, 64, 64),
                    Fire(128, 32, 128, 128),
                    nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(256, 32, 128, 128),
                    Fire(256, 48, 192, 192),
                    Fire(384, 48, 192, 192),
                    Fire(384, 64, 256, 256),
                    nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(512, 64, 256, 256),
                )
            else:
                self.features = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=3, stride=2),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(64, 16, 64, 64),
                    Fire(128, 16, 64, 64),
                    nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(128, 32, 128, 128),
                    Fire(256, 32, 128, 128),
                    nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                    Fire(256, 48, 192, 192),
                    Fire(384, 48, 192, 192),
                    Fire(384, 64, 256, 256),
                    Fire(512, 64, 256, 256),
                )
            # Final convolution is initialized differently form the rest
            final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
            self.classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                final_conv,
                nn.ReLU(inplace=True),
                nn.AvgPool2d(13, stride=1)
            )
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    if m is final_conv:
                        init.normal_(m.weight, mean=0.0, std=0.01)
                    else:
                        init.kaiming_uniform_(m.weight)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
    
        def forward(self, x):
            x = self.features(x)
            x = self.classifier(x)
            return x.view(x.size(0), self.num_classes)

    首先依然是定义网络层,在这里有两个版本,差别不大,都是fire模块的堆积,最后经过全局平均池化输出1000类。这里对卷积层采用了不同的初始化策略,我还没仔细研究过,就不说了。

  • 相关阅读:
    Mybatis-plus学习笔记(一)
    Mysql基础(四)分组查询及连接查询
    Mysql 基础(三)排序查询及常用函数
    CyclicBarrier 使用详解
    countDownLatch
    pom所有依赖version红色但是不影响运行
    iText5实现Java生成PDF文件完整版
    【Maven】---Nexus私服配置Setting和Pom
    引用、指针、const、define、static、sizeof、左值右值
    事物隔离级别、MVCC以及数据库中常见锁介绍
  • 原文地址:https://www.cnblogs.com/wzyuan/p/9710565.html
Copyright © 2011-2022 走看看