  • PyTorch: implementing upsampling without transposed convolutions

    Upsampling is generally done in one of two ways: 1) interpolation (e.g., nearest-neighbor or bilinear) followed by an ordinary convolution, or 2) transposed convolution.

    How to implement the second method in PyTorch is covered in the link above.
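    For a quick sense of the difference, here is a shape-only sketch contrasting the two approaches (my own example, not from the original post):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(1, 64, 56, 56)

    # method 1: interpolation followed by an ordinary convolution
    up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
    y1 = nn.Conv2d(64, 32, 3, padding=1)(up)

    # method 2: transposed convolution with matching output size
    y2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)(x)

    print(y1.shape, y2.shape)  # both torch.Size([1, 32, 112, 112])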

    What I want to show here is how to implement the first method in PyTorch:

    Example:

    1) A generator implemented with the torch.nn module:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class ConvLayer(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, stride):
            super(ConvLayer, self).__init__()
            # reflection padding gives the same output size as zero padding
            # of kernel_size // 2, but avoids border artifacts
            padding = kernel_size // 2
            self.reflection_pad = nn.ReflectionPad2d(padding)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
    
        def forward(self, x):
            out = self.reflection_pad(x)
            out = self.conv(out)
    
            return out
    
    class Generator(nn.Module):
        def __init__(self, in_channels):
            super(Generator, self).__init__()
            self.in_channels = in_channels
    
            self.encoder = nn.Sequential(
                ConvLayer(self.in_channels, 32, 3, 2),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                ConvLayer(32, 64, 3, 2),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                ConvLayer(64, 128, 3, 2),
            )
    
            # Upsample has no learnable parameters, so a single instance
            # can safely be reused at every decoder stage below
            upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.decoder = nn.Sequential(
                upsample,
                nn.Conv2d(128, 64, 1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                upsample,
                nn.Conv2d(64, 32, 1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                upsample,
                nn.Conv2d(32, 3, 1),
                nn.Tanh()
            )
    
        def forward(self, x):
            x = self.encoder(x)
            out = self.decoder(x)
    
            return out
    
    def test():
        net = Generator(3)
        for module in net.children():
            print(module)
        x = torch.randn(2, 3, 224, 224)  # Variable is deprecated; a plain tensor works
        output = net(x)
        print('output :', output.size())
        print(type(output))
    
    if __name__ == '__main__':
        test()

    This returns:

    Sequential(
      (0): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
      )
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      )
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
      )
    )
    Sequential(
      (0): Upsample(scale_factor=2, mode=bilinear)
      (1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Upsample(scale_factor=2, mode=bilinear)
      (5): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): ReLU()
      (8): Upsample(scale_factor=2, mode=bilinear)
      (9): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
      (10): Tanh()
    )
    output : torch.Size([2, 3, 224, 224])
    <class 'torch.Tensor'>
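    To see how the spatial size round-trips through the network, you can inspect the encoder output directly (a quick check of my own, assuming the Generator class above is in scope): 224 is halved three times to 28 by the strided convolutions, then doubled three times back to 224 by the upsampling steps.

    net = Generator(3)
    x = torch.randn(2, 3, 224, 224)
    z = net.encoder(x)
    print(z.shape)  # torch.Size([2, 128, 28, 28]) -- 224 halved three times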

    However, this produces a warning:

     UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
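    If you want to silence the warning but still keep the upsampling step inside an nn.Sequential, one option is a thin module wrapper around F.interpolate. This is a sketch of my own (the Interpolate class is neither part of PyTorch nor of the original post):

    import torch.nn as nn
    import torch.nn.functional as F

    class Interpolate(nn.Module):
        # minimal module wrapper so F.interpolate can live inside nn.Sequential
        def __init__(self, scale_factor, mode='bilinear', align_corners=True):
            super(Interpolate, self).__init__()
            self.scale_factor = scale_factor
            self.mode = mode
            self.align_corners = align_corners

        def forward(self, x):
            return F.interpolate(x, scale_factor=self.scale_factor,
                                 mode=self.mode, align_corners=self.align_corners)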

    Here, instead, the model is rewritten to call the torch.nn.functional module directly in forward:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class ConvLayer(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, stride):
            super(ConvLayer, self).__init__()
            padding = kernel_size // 2
            self.reflection_pad = nn.ReflectionPad2d(padding)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
    
        def forward(self, x):
            out = self.reflection_pad(x)
            out = self.conv(out)
    
            return out
    
    class Generator(nn.Module):
        def __init__(self, in_channels):
            super(Generator, self).__init__()
            self.in_channels = in_channels
    
            self.encoder = nn.Sequential(
                ConvLayer(self.in_channels, 32, 3, 2),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                ConvLayer(32, 64, 3, 2),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                ConvLayer(64, 128, 3, 2),
            )
    
            self.decoder1 = nn.Sequential(
                nn.Conv2d(128, 64, 1),
                nn.BatchNorm2d(64),
                nn.ReLU()
            )
            self.decoder2 = nn.Sequential(
                nn.Conv2d(64, 32, 1),
                nn.BatchNorm2d(32),
                nn.ReLU()
            )
            self.decoder3 = nn.Sequential(
                nn.Conv2d(32, 3, 1),
                nn.Tanh()
            )
    
        def forward(self, x):
            x = self.encoder(x)
            # F.interpolate replaces the deprecated nn.Upsample module
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
            x = self.decoder1(x)
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
            x = self.decoder2(x)
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
            out = self.decoder3(x)
    
            return out
    
    def test():
        net = Generator(3)
        for module in net.children():
            print(module)
        x = torch.randn(2, 3, 224, 224)  # Variable is deprecated; a plain tensor works
        output = net(x)
        print('output :', output.size())
        print(type(output))
    
    if __name__ == '__main__':
        test()

    This returns:

    Sequential(
      (0): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
      )
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      )
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
      )
    )
    Sequential(
      (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    Sequential(
      (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    Sequential(
      (0): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
      (1): Tanh()
    )
    output : torch.Size([2, 3, 224, 224])
    <class 'torch.Tensor'>
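    One detail worth knowing when choosing the interpolation mode: align_corners only applies to the linear modes (linear, bilinear, bicubic, trilinear), and passing it with mode='nearest' raises an error. A quick illustration (my own example):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 28, 28)
    up_nearest = F.interpolate(x, scale_factor=2, mode='nearest')  # no align_corners here
    up_bilinear = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
    print(up_nearest.shape, up_bilinear.shape)  # both torch.Size([1, 3, 56, 56])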
  • Original post: https://www.cnblogs.com/wanghui-garcia/p/11400866.html