zoukankan      html  css  js  c++  java
  • 使用GAN生成图片

    一:卷积神经网络的搭建

    class NetG(nn.Module):
        '''
        生成器定义
        '''
    
        def __init__(self, opt):
            super(NetG, self).__init__()
            ngf = opt.ngf  # 生成器feature map数
    
            self.maina = nn.Sequential(
                # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
                nn.ConvTranspose2d(opt.nz, ngf * 16, 4, 1, 0, bias=False),
                #noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
                nn.BatchNorm2d(ngf * 16),
                nn.ReLU(True),
    
                # 上一步的输出形状:(ngf*8) x 4 x 4
                #
                nn.ConvTranspose2d(ngf * 16, ngf * 10, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 10),
                nn.ReLU(True),
                # # # 上一步的输出形状: (ngf*4) x 8 x 8
                # # #
                nn.ConvTranspose2d(ngf * 10, ngf * 8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                # # # 上一步的输出形状: (ngf*2) x 16 x 16
                # #
                nn.ConvTranspose2d(ngf * 8, ngf*8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ngf*8),
                nn.ReLU(True),
                # # # 上一步的输出形状:(ngf) x 32 x 32
                nn.ConvTranspose2d(ngf * 8, ngf * 6, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 6),
                nn.ReLU(True),
                # #
                nn.ConvTranspose2d(ngf * 6, ngf , 4, 1, 1, bias=False),
                nn.BatchNorm2d(ngf ),
                nn.ReLU(True),
                nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
                nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
                # 输出形状:3 x 96 x 96
            )
    
        def forward(self, input):
            return self.maina(input)
    

      

    class NetD(nn.Module):
        '''
        判别器定义
        '''
    
        def __init__(self, opt):
            super(NetD, self).__init__()
            ndf = opt.ndf
            self.main = nn.Sequential(
                # 输入 3 x 96 x 96
                nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf) x 32 x 32
    
                nn.Conv2d(ndf, ndf * 6, 4, 1, 1, bias=False),
                nn.BatchNorm2d(ndf * 6),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*2) x 16 x 16
                #
                nn.Conv2d(ndf * 6, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                # # 输出 (ndf*4) x 8 x 8
                #
                nn.Conv2d(ndf * 8, ndf * 8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 8, ndf * 10, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ndf * 10),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 10, ndf * 16, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 16),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*8) x 4 x 4
                #
                nn.Conv2d(ndf * 16, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()  # 输出一个数(概率)
            )
    
        def forward(self, input):
            return self.main(input).view(-1)
    

      训练后生成的图片

     

  • 相关阅读:
    如何使用dig命令挖掘域名解析信息
    网络地址转换 NAT 配置
    Win10 安装子系统 GUI 界面
    送给发烧友:Python条件语句的七种写法T
    这是一个可以显示Linux命令的工具
    网页游戏破解 我是武神
    仙侠道破解
    心动最新页游 仙侠道 破解笔记
    通用网页游戏伤害公式。
    mysql: error while loading shared libraries: libmysqlclient.so.16
  • 原文地址:https://www.cnblogs.com/dudu1992/p/9110287.html
Copyright © 2011-2022 走看看