zoukankan      html  css  js  c++  java
  • 使用GAN生成图片

    一:卷积神经网络的搭建

    class NetG(nn.Module):
        '''
        生成器定义
        '''
    
        def __init__(self, opt):
            super(NetG, self).__init__()
            ngf = opt.ngf  # 生成器feature map数
    
            self.maina = nn.Sequential(
                # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
                nn.ConvTranspose2d(opt.nz, ngf * 16, 4, 1, 0, bias=False),
                #noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
                nn.BatchNorm2d(ngf * 16),
                nn.ReLU(True),
    
                # 上一步的输出形状:(ngf*8) x 4 x 4
                #
                nn.ConvTranspose2d(ngf * 16, ngf * 10, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 10),
                nn.ReLU(True),
                # # # 上一步的输出形状: (ngf*4) x 8 x 8
                # # #
                nn.ConvTranspose2d(ngf * 10, ngf * 8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                # # # 上一步的输出形状: (ngf*2) x 16 x 16
                # #
                nn.ConvTranspose2d(ngf * 8, ngf*8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ngf*8),
                nn.ReLU(True),
                # # # 上一步的输出形状:(ngf) x 32 x 32
                nn.ConvTranspose2d(ngf * 8, ngf * 6, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 6),
                nn.ReLU(True),
                # #
                nn.ConvTranspose2d(ngf * 6, ngf , 4, 1, 1, bias=False),
                nn.BatchNorm2d(ngf ),
                nn.ReLU(True),
                nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
                nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
                # 输出形状:3 x 96 x 96
            )
    
        def forward(self, input):
            return self.maina(input)
    

      

    class NetD(nn.Module):
        '''
        判别器定义
        '''
    
        def __init__(self, opt):
            super(NetD, self).__init__()
            ndf = opt.ndf
            self.main = nn.Sequential(
                # 输入 3 x 96 x 96
                nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf) x 32 x 32
    
                nn.Conv2d(ndf, ndf * 6, 4, 1, 1, bias=False),
                nn.BatchNorm2d(ndf * 6),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*2) x 16 x 16
                #
                nn.Conv2d(ndf * 6, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                # # 输出 (ndf*4) x 8 x 8
                #
                nn.Conv2d(ndf * 8, ndf * 8, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 8, ndf * 10, 4, 3, 1, bias=False),
                nn.BatchNorm2d(ndf * 10),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 10, ndf * 16, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 16),
                nn.LeakyReLU(0.2, inplace=True),
                # 输出 (ndf*8) x 4 x 4
                #
                nn.Conv2d(ndf * 16, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()  # 输出一个数(概率)
            )
    
        def forward(self, input):
            return self.main(input).view(-1)
    

      训练后生成的图片

     

  • 相关阅读:
    SAP GUI登陆 安全性提示  出现乱码
    获取sap登陆用户名的中文描述
    jQuery页面替换+php代码实现搜索后分页
    Linux下更新Git
    DirectoryInfo.GetFiles 方法 (String, SearchOption)
    WF中创建持久化服务和跟踪服务数据库
    ConfigurationSection自定义配置的使用
    微软企业库5.0系统(一):使用缓存 Microsoft.Practices.EnterpriseLibrary.Caching(初级篇)
    HttpWebRequest传输Cookie
    jquery优化规则
  • 原文地址:https://www.cnblogs.com/dudu1992/p/9110287.html
Copyright © 2011-2022 走看看