zoukankan      html  css  js  c++  java
  • pytorch-VGG网络

    VGG网络结构 

    第一层: 3x3x3x64, 步长为1, padding=1 

    第二层: 3x3x64x64, 步长为1, padding=1 

    第三层: 3x3x64x128, 步长为1, padding=1

    第四层: 3x3x128x128, 步长为1, padding=1

    第五层: 3x3x128x256, 步长为1, padding=1

    第六层: 3x3x256x256, 步长为1, padding=1

    第七层: 3x3x256x256, 步长为1, padding=1

    第八层: 3x3x256x512, 步长为1, padding=1 

    第九层: 3x3x512x512, 步长为1, padding=1 

    第十层:3x3x512x512, 步长为1, padding=1 

    第十一层: 3x3x512x512, 步长为1, padding=1 

    第十二层: 3x3x512x512, 步长为1, padding=1 

    第十三层:3x3x512x512, 步长为1, padding=1 

    第十四层: 512*7*7, 4096的全连接操作

    第十五层: 4096, 4096的全连接操作

    第十六层: 4096, num_classes 的 全连接操作

    import torch
    from torch import nn
    
    class VGG(nn.Module):
        def __init__(self, num_classes):
            super(VGG, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(64, 64, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(128, 256, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(256, 512, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d(kernel_size=2, stride=2),)
            self.classifier = nn.Sequential(
                nn.Linear(512*7*7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes)
            )
    
        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
    
            return x
  • 相关阅读:
    C#基础笔记(第十四天)
    C#基础笔记(第十三天)
    C#基础整理(二)
    C#基础笔记(第十二天)
    C#基础笔记(第十一天)
    C#基础笔记(第十天)
    C#基础笔记(第九天)
    [PyTorch 学习笔记] 2.3 二十二种 transforms 图片数据预处理方法
    [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制
    [PyTorch 学习笔记] 2.1 DataLoader 与 DataSet
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/11729565.html
Copyright © 2011-2022 走看看