zoukankan      html  css  js  c++  java
  • PyTorch——模型搭建——VGG(一)

    模型搭建

    0、VGG模型

    1、torchvision自带的VGG模型

    1 import torch
    2 import torchvision
    3 from torchsummary import summary
    4 
    # Use the GPU when one is available; otherwise fall back to the CPU so the
    # example also runs on machines without CUDA (a bare .cuda() would raise).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # torchvision's VGG-13 with an 8-way classifier head.
    model = torchvision.models.vgg13(num_classes=8).to(device)
    # torchsummary builds its dummy input on `device`, so it must match the model.
    summary(model, (3, 224, 224), device=device)

    2、自己搭建VGG模型

     1 import torch
     2 import torch.nn as nn
     3 import torch.nn.functional as F
     4 from torchsummary import summary
     5 
     class VGG(nn.Module):
         """VGG-style image classifier with BatchNorm after every hidden layer.

         The network has five convolutional stages with 64, 128, 256, 512 and
         512 output channels. Stage ``i`` stacks ``arch[i]`` blocks of
         3x3 conv (stride 1, padding 1, no bias) -> BatchNorm2d -> ReLU, and each
         stage is followed by a 2x2 max-pool. The feature map is then pooled to
         a fixed 7x7 grid, flattened, and passed through a three-layer
         classifier (4096 -> 4096 -> ``num_classes``). BatchNorm1d + ReLU sit
         between the classifier layers. The output is raw logits (no softmax).

         Args:
             arch: Number of conv blocks in each of the five stages, e.g.
                 ``[2, 2, 3, 3, 3]``. Only the first five entries are read.
             num_classes: Size of the output logit vector.
             in_channels: Number of channels in the input images.

         Input is expected as ``(N, in_channels, H, W)``. ``H`` and ``W`` must be
         at least 32, because each of the five 2x2 max-pools must leave at least
         one pixel. The classifier uses BatchNorm1d, so training mode needs a
         batch size greater than 1.
         """

         def __init__(self, arch, num_classes=1000, in_channels=3):
             super().__init__()
             # Running channel count. _make_layer reads and advances it so each
             # stage starts from the previous stage's output width.
             self.in_channels = in_channels
             self.conv3_64 = self._make_layer(64, arch[0])
             self.conv3_128 = self._make_layer(128, arch[1])
             self.conv3_256 = self._make_layer(256, arch[2])
             self.conv3_512a = self._make_layer(512, arch[3])
             self.conv3_512b = self._make_layer(512, arch[4])
             # Pin the spatial size fed to fc1 at 7x7. This is an identity for
             # 224x224 inputs and lets other input sizes (>= 32) work too.
             self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
             self.flatten = nn.Flatten()
             # self.in_channels now holds the width of the last conv stage
             # (512 whenever arch[4] > 0).
             self.fc1 = nn.Linear(7 * 7 * self.in_channels, 4096)
             self.bn1 = nn.BatchNorm1d(4096)
             self.fc2 = nn.Linear(4096, 4096)
             # fc2 gets its own normalization. Reusing bn1 would share its
             # affine parameters and mix the running statistics of two layers.
             self.bn2 = nn.BatchNorm1d(4096)
             self.fc3 = nn.Linear(4096, num_classes)

         def _make_layer(self, channels, num):
             """Return ``num`` chained conv-BN-ReLU blocks producing ``channels`` maps.

             Side effect: sets ``self.in_channels`` to ``channels`` whenever
             ``num > 0``, so the next stage chains from this one.
             """
             layers = []
             for _ in range(num):
                 # bias=False: the BatchNorm2d right after removes the per-channel
                 # mean, so a conv bias would be redundant.
                 layers.append(nn.Conv2d(self.in_channels, channels, 3, stride=1, padding=1, bias=False))
                 layers.append(nn.BatchNorm2d(channels))
                 layers.append(nn.ReLU())
                 self.in_channels = channels
             return nn.Sequential(*layers)

         def forward(self, x):
             """Return class logits for a batch of images ``x``."""
             # Each conv stage is followed by a 2x2 max-pool that halves H and W.
             for stage in (self.conv3_64, self.conv3_128, self.conv3_256,
                           self.conv3_512a, self.conv3_512b):
                 x = F.max_pool2d(stage(x), 2)
             x = self.avgpool(x)
             x = self.flatten(x)
             x = F.relu(self.bn1(self.fc1(x)))
             x = F.relu(self.bn2(self.fc2(x)))
             return self.fc3(x)
    52 
    def VGG_11(num_classes=6):
        """Build an 11-layer VGG: 8 conv layers (1, 1, 2, 2, 2 per stage) + 3 FC.

        Args:
            num_classes: Output size. The default of 6 looks like a demo
                value (NOTE(review): pass the real class count).
        """
        return VGG([1, 1, 2, 2, 2], num_classes=num_classes)
    55 
    def VGG_13(num_classes=7):
        """Build a 13-layer VGG: 10 conv layers (2 per stage) + 3 FC.

        Args:
            num_classes: Output size. The default of 7 looks like a demo
                value (NOTE(review): pass the real class count).
        """
        return VGG([2, 2, 2, 2, 2], num_classes=num_classes)
    58 
    def VGG_16(num_classes=8):
        """Build a 16-layer VGG: 13 conv layers (2, 2, 3, 3, 3 per stage) + 3 FC.

        Args:
            num_classes: Output size. The default of 8 looks like a demo
                value (NOTE(review): pass the real class count).
        """
        return VGG([2, 2, 3, 3, 3], num_classes=num_classes)
    61 
    def VGG_19(num_classes=9):
        """Build a 19-layer VGG: 16 conv layers (2, 2, 4, 4, 4 per stage) + 3 FC.

        Args:
            num_classes: Output size. The default of 9 looks like a demo
                value (NOTE(review): pass the real class count).
        """
        return VGG([2, 2, 4, 4, 4], num_classes=num_classes)
    64 
    # Use the GPU when one is available; otherwise fall back to the CPU so the
    # example also runs on machines without CUDA (a bare .cuda() would raise).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = VGG_16().to(device)
    # torchsummary builds its dummy input on `device`, so it must match the model.
    summary(net, (3, 224, 224), device=device)
  • 相关阅读:
    docker (centOS 7) 使用笔记3
    docker (centOS 7) 使用笔记4
    docker (centOS 7) 使用笔记2
    docker (centOS 7) 使用笔记1
    docker (centOS 7) 使用笔记3
    CentOS7 修改时区、charset
    p12(PKCS12)和jks互相转换
    tomcat7 日志设置为log4j
    Redis概述与基本操作
    Django学习笔记之安全
  • 原文地址:https://www.cnblogs.com/timelesszxl/p/14555192.html
Copyright © 2011-2022 走看看