zoukankan      html  css  js  c++  java
  • pytorch-cifar10分类网络结构

    cifar10主要是由32x32的三通道彩色图, 总共10个类别,这里我们使用残差网络构造网络结构

    网络结构:

    第一层:首先经过一个卷积,归一化,激活 32x32x16 -> 32x32x16

    第二层:  通过一多个残差模型

    残差模块的网络构造:

               如果stride != 1 or in_channel != out_channel, 就构造downsample网络结构进行降采样操作

               利用残差模块进行第一次残差卷积, 将downsample传入

              连续进行多次的残差卷积

    from torchvision import transforms
    from torch import nn
    # 首先对图片进行数据转换
    
    train_transform = transforms.Compose([
        transforms.Scale(40), # 相当于是resize操作,
        transforms.RandomHorizontalFlip(), # 表示进行左右的翻转
        transforms.RandomCrop(32), #表示进行随机的裁剪
        transforms.ToTensor(), # 将数据转换为tensor格式
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 进行-均值 / 标准差, 将数据转换为-1, 1 之间
    
    ])
    
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    
    def conv3x3(in_channels, out_channels, stride=1):
        return nn.Conv2d(in_channels,
                         out_channels,
                         kernel_size=3,
                         stride=stride,
                         padding=1,
                         bias=False)
    
    class ResidualBlock(nn.Module):
        def __init__(self, in_channels, out_channels, stride=1, downsample=None):
            super(ResidualBlock, self).__init__()
            self.conv1 = conv3x3(in_channels, out_channels, stride=1)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU(True)
            self.conv2 = conv3x3(out_channels, out_channels, stride=1)
            self.bn = nn.BatchNorm2d(out_channels)
            self.downsample = downsample
    
        def forward(self, x):
            residual = x
            out = self.conv1(x)
            out = self.bn(x)
            out = self.relu(x)
            out = self.conv2(x)
            out = self.bn(x)
            if self.downsample:
                residual = self.downsample(x)
            out += residual
            return self.relu(out)
    
    
    class ResNet(nn.Module):
        def __init__(self, block, layers, num_classes=10):
            super(ResNet, self).__init__()
            self.in_channels = 16
            self.conv = conv3x3(3, 16)
            self.bn = nn.BatchNorm2d(self.in_channels)
            self.relu = nn.ReLU(True)
            self.layers1 = self.make_block(block, 16, layers[0])
            self.layers2 = self.make_block(block, 32, layers[0])
            self.layers3 = self.make_block(block, 64, layers[1])
            self.avg_pool = nn.AvgPool2d(8)
            self.fc = nn.Linear(64, num_classes)
    
        def make_block(self, block, out_channels, blocks, stride=1):
            downsample = None
            if stride != 1 or out_channels != self.in_channels:
                downsample = nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
            layers = []
            layers.append(block(self.in_channels, out_channels, stride=stride, downsample = downsample))
            for i in blocks:
                layers.append(block(self.out_channels, out_channels, stride=stride, downsample=downsample))
    
            return nn.Sequential(*layers)
    
        def forward(self, x):
            out = self.conv(x)
            out = self.bn(out)
            out = self.relu(out)
            out = self.layers1(out)
            out= self.layers2(out)
            out = self.layers3(out)
            out = self.avg_pool(out)
            out = self.fc(out)
    
            return out
  • 相关阅读:
    如何增加VM Ware虚拟机的硬盘空间
    安装完成oracle 11g R2 后,使用sqlplus 报错"sqlplus: error while loading shared libraries" ...
    listener.ora
    ExtJS项目框架有关问题讨论
    Oracle启动监听报错:The listener supports no services解决
    Linq学习笔记一
    PAT 1054 The Dominant Color[简单][运行超时的问题]
    Andrew NgML第十八章大规模机器学习
    PAT 1042 Shuffling Machine[难]
    PAT 1103 Integer Factorization[难]
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/11738505.html
Copyright © 2011-2022 走看看