zoukankan      html  css  js  c++  java
  • 使用GAN生成图片

    一:卷积神经网络的搭建

    class NetG(nn.Module):
        '''
        生成器定义
        '''
    
        def __init__(self, opt):
            super(NetG, self).__init__()
            ngf = opt.ngf  # 生成器feature map数
    
            self.maina = nn.Sequential(
                # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
                nn.ConvTranspose2d(opt.nz, ngf * 16, 4, 1, 0, bias=False),
                #noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
                nn.BatchNorm2d(ngf * 16),
                nn.ReLU(True),
    
                # 上一步的输出形状:(ngf*8) x 4 x 4
                #
                nn.ConvTranspose2d(ngf * 16, ngf * 10, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 10),
                nn.ReLU(True),
                # # # 上一步的输出形状: (ngf*4) x 8 x 8
                # # #
                nn.ConvTranspose2d(ngf * 10, ngf * 8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                # # # 上一步的输出形状: (ngf*2) x 16 x 16
                # #
                nn.ConvTranspose2d(ngf * 8, ngf*8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ngf*8),
                nn.ReLU(True),
                # # # 上一步的输出形状:(ngf) x 32 x 32
                nn.ConvTranspose2d(ngf * 8, ngf * 6, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 6),
                nn.ReLU(True),
                # #
                nn.ConvTranspose2d(ngf * 6, ngf , 4, 1, 1, bias=False),
                nn.BatchNorm2d(ngf ),
                nn.ReLU(True),
                nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
                nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
                # 输出形状:3 x 96 x 96
            )
    
        def forward(self, input):
            return self.maina(input)
    

      

    class NetD(nn.Module):
        '''
        判别器定义
        '''
    
        def __init__(self, opt):
            super(NetD, self).__init__()
            ndf = opt.ndf
            self.main = nn.Sequential(
                # 输入 3 x 96 x 96
                nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf) x 32 x 32
    
                nn.Conv2d(ndf, ndf * 6, 4, 1, 1, bias=False),
                nn.BatchNorm2d(ndf * 6),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*2) x 16 x 16
                #
                nn.Conv2d(ndf * 6, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                # # 输出 (ndf*4) x 8 x 8
                #
                nn.Conv2d(ndf * 8, ndf * 8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 8, ndf * 10, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ndf * 10),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 10, ndf * 16, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 16),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*8) x 4 x 4
                #
                nn.Conv2d(ndf * 16, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()  # 输出一个数(概率)
            )
    
        def forward(self, input):
            return self.main(input).view(-1)
    

      训练后生成的图片

     

  • 相关阅读:
    使用vue-cli4 F5刷新 报错:Uncaught SyntaxError: Unexpected token <
    visual stuido 删除虚拟目录
    visual studio2017 使用内联变量 编译失败 但不显示错误信息
    spring boot 整合CXF创建web service
    .net 解析以竖线,美元符等分割符分割的字符串为实体
    c# 动态构造实体属性的lambda Expression表达式
    spring boot 创建web service及调用
    JPA使用log4jdbc输出sql日志
    JPA 使用log4j2输出SQL日志到文件
    JPA 使用logback输出SQL日志到文件
  • 原文地址:https://www.cnblogs.com/dudu1992/p/9110287.html
Copyright © 2011-2022 走看看