  • PyTorch: Model Building: MobileNetV1 (Part 1)

    The MobileNetV1 model

    MobileNetV1 built into PyTorch

    PyTorch (torchvision) does not include a MobileNetV1 implementation.

    For details, see: https://pytorch.org/vision/stable/models.html

    Building it yourself

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class Block(nn.Module):
        """Depthwise conv + pointwise conv, each followed by BatchNorm and ReLU."""

        def __init__(self, in_channels, out_channels, stride=1):
            super().__init__()
            # Depthwise 3x3 conv: groups=in_channels applies one filter per input channel.
            self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False)
            self.bn1 = nn.BatchNorm2d(in_channels)
            # Pointwise 1x1 conv: mixes channels and changes the channel count.
            self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
            self.bn2 = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            x = F.relu(self.bn1(self.conv1(x)))
            x = F.relu(self.bn2(self.conv2(x)))
            return x


    class MobileNet(nn.Module):
        # An int means out_channels with stride 1; a tuple means (out_channels, stride).
        # 13 depthwise-separable blocks, following the MobileNetV1 paper.
        cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024]

        def __init__(self, num_classes=10):
            super().__init__()
            # Stem: a standard 3x3 conv with stride 2 (224x224 -> 112x112).
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(32)
            self.layers = self._make_layers(in_planes=32)
            self.linear = nn.Linear(1024, num_classes)

        def _make_layers(self, in_planes):
            layers = []
            for x in self.cfg:
                out_planes = x if isinstance(x, int) else x[0]
                stride = 1 if isinstance(x, int) else x[1]
                layers.append(Block(in_planes, out_planes, stride))
                in_planes = out_planes
            return nn.Sequential(*layers)

        def forward(self, x):
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.layers(out)
            # For 224x224 inputs the feature map is 7x7 here, so this pools it to 1x1.
            out = F.avg_pool2d(out, 7)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
            return out


    # Smoke test: a batch of 32 RGB images at 224x224, classified into 8 classes.
    inputs = torch.randn(32, 3, 224, 224)
    net = MobileNet(8)
    out = net(inputs)
    print(out.size())  # torch.Size([32, 8])
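
    To make the benefit of the depthwise + pointwise split in `Block` concrete, the sketch below compares the number of convolution weights with those of a single standard 3x3 convolution. The helper names `standard_conv_params` and `separable_conv_params` are hypothetical and not part of the original code; BatchNorm parameters are ignored and biases are assumed absent (`bias=False`, as in the listing above).

```python
def standard_conv_params(c_in, c_out, k=3):
    """Weights in one standard k x k convolution (no bias)."""
    return k * k * c_in * c_out


def separable_conv_params(c_in, c_out, k=3):
    """Weights in a depthwise k x k conv followed by a pointwise 1x1 conv (no bias)."""
    depthwise = k * k * c_in   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1x1 conv that mixes channels
    return depthwise + pointwise


# Example: one of the 512 -> 512 blocks from cfg.
std = standard_conv_params(512, 512)
sep = separable_conv_params(512, 512)
print(std, sep, round(std / sep, 2))  # 2359296 266752 8.84
```

    For a 3x3 kernel the saving approaches a factor of 9 as the number of output channels grows, which is why most of the network's weights end up in the 1x1 pointwise layers.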
  • Original article: https://www.cnblogs.com/timelesszxl/p/14566107.html