zoukankan      html  css  js  c++  java
  • 深度学习 对抗生成网络 使用生成对抗网络生成图片

     这是最新找到的  对抗生成网络的代码,亲测可以跑通。前几天也上传了一个网上找到的代码,但是这回这个代码中判别网络的假数据中加入了 detach() 函数, 网上查找说这个函数可以切断神经网络的反向传导,虽然不是很理解,但总是感觉这个更对一些。对于  detach 这个函数在这里面的作用网上怎么说的都有,不过个人感觉最有说服力的说法是 减少没有必要的运算,毕竟在判别网络中我们是不需要修改生成网络的参数的,也就是说这个时候求解生成网络的梯度,对其进行反向求导是没有必要的,而这个说法和代码中的注释部分相合。

    #encoding:UTF-8
    
    
    #读入CIFAR-10数据
    from torch.utils.data import DataLoader
    from torchvision.datasets import CIFAR10
    import torchvision.transforms as transforms
    from torchvision.utils import save_image
    
    dataset = CIFAR10(root='./data', download=True,
            transform=transforms.ToTensor())
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    
    for batch_idx, data in enumerate(dataloader):
        real_images, _ = data
        batch_size = real_images.size(0)
        print ('#{} has {} images.'.format(batch_idx, batch_size))
        if batch_idx % 100 == 0:
            path = './data/CIFAR10_shuffled_batch{:03d}.png'.format(batch_idx)
            save_image(real_images, path, normalize=True)
    
    
    #搭建生成网络和鉴别网络
    import torch.nn as nn
    
    # 搭建生成网络
    latent_size = 64 # 潜在大小
    n_channel = 3 # 输出通道数
    n_g_feature = 64 # 生成网络隐藏层大小
    gnet = nn.Sequential( 
            # 输入大小 = (64, 1, 1)
            nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4,
                    bias=False),
            nn.BatchNorm2d(4 * n_g_feature),
            nn.ReLU(),
            # 大小 = (256, 4, 4)
            nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, kernel_size=4,
                    stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2 * n_g_feature),
            nn.ReLU(),
            # 大小 = (128, 8, 8)
            nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, kernel_size=4,
                    stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_g_feature),
            nn.ReLU(),
            # 大小 = (64, 16, 16)
            nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4,
                    stride=2, padding=1),
            nn.Sigmoid(),
            # 图片大小 = (3, 32, 32)
            )
    print (gnet)
    
    
    # 搭建鉴别网络
    n_d_feature = 64 # 鉴别网络隐藏层大小
    dnet = nn.Sequential( 
            # 图片大小 = (3, 32, 32)
            nn.Conv2d(n_channel, n_d_feature, kernel_size=4,
                    stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 大小 = (64, 16, 16)
            nn.Conv2d(n_d_feature, 2 * n_d_feature, kernel_size=4,
                    stride=2, padding=1, bias=False),
            nn.BatchNorm2d(2 * n_d_feature),
            nn.LeakyReLU(0.2),
            # 大小 = (128, 8, 8)
            nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, kernel_size=4,
                    stride=2, padding=1, bias=False),
            nn.BatchNorm2d(4 * n_d_feature),
            nn.LeakyReLU(0.2),
            # 大小 = (256, 4, 4)
            nn.Conv2d(4 * n_d_feature, 1, kernel_size=4),
            # 对数赔率张量大小 = (1, 1, 1)
            )
    print(dnet)        
    
    
    import torch.nn.init as init
    #初始化权重值
    def weights_init(m): # 用于初始化权重值的函数
        if type(m) in [nn.ConvTranspose2d, nn.Conv2d]:
            init.xavier_normal_(m.weight)
        elif type(m) == nn.BatchNorm2d:
            init.normal_(m.weight, 1.0, 0.02)
            init.constant_(m.bias, 0)
    
    gnet.apply(weights_init)
    dnet.apply(weights_init)
    
    
    
    #主程序
    import torch
    import torch.optim
    
    # 损失
    criterion = nn.BCEWithLogitsLoss()
    
    # 优化器
    goptimizer = torch.optim.Adam(gnet.parameters(),
            lr=0.0002, betas=(0.5, 0.999))
    doptimizer = torch.optim.Adam(dnet.parameters(),
            lr=0.0002, betas=(0.5, 0.999))
    
    # 用于测试的固定噪声,用来查看相同的潜在张量在训练过程中生成图片的变换
    batch_size = 64
    fixed_noises = torch.randn(batch_size, latent_size, 1, 1)
    
    # 训练过程
    epoch_num = 10
    for epoch in range(epoch_num):
        for batch_idx, data in enumerate(dataloader):
            # 载入本批次数据
            real_images, _ = data
            batch_size = real_images.size(0)
    
            # 训练鉴别网络
            labels = torch.ones(batch_size) # 真实数据对应标签为1
            preds = dnet(real_images) # 对真实数据进行判别
            outputs = preds.reshape(-1)
            dloss_real = criterion(outputs, labels) # 真实数据的鉴别器损失
            dmean_real = outputs.sigmoid().mean()
                    # 计算鉴别器将多少比例的真数据判定为真,仅用于输出显示
    
            noises = torch.randn(batch_size, latent_size, 1, 1) # 潜在噪声
            fake_images = gnet(noises) # 生成假数据
            labels = torch.zeros(batch_size) # 假数据对应标签为0
            fake = fake_images.detach()
                    # 使得梯度的计算不回溯到生成网络,可用于加快训练速度.删去此步结果不变
            preds = dnet(fake) # 对假数据进行鉴别
            outputs = preds.view(-1)
            dloss_fake = criterion(outputs, labels) # 假数据的鉴别器损失
            dmean_fake = outputs.sigmoid().mean()
                    # 计算鉴别器将多少比例的假数据判定为真,仅用于输出显示
    
            dloss = dloss_real + dloss_fake # 总的鉴别器损失
            dnet.zero_grad()
            dloss.backward()
            doptimizer.step()
    
            # 训练生成网络
            labels = torch.ones(batch_size)
                    # 生成网络希望所有生成的数据都被认为是真数据
            preds = dnet(fake_images) # 把假数据通过鉴别网络
            outputs = preds.view(-1)
            gloss = criterion(outputs, labels) # 真数据看到的损失
            gmean_fake = outputs.sigmoid().mean()
                    # 计算鉴别器将多少比例的假数据判定为真,仅用于输出显示
            gnet.zero_grad()
            gloss.backward()
            goptimizer.step()
    
            # 输出本步训练结果
            print('[{}/{}]'.format(epoch, epoch_num) +
                    '[{}/{}]'.format(batch_idx, len(dataloader)) +
                    '鉴别网络损失:{:g} 生成网络损失:{:g}'.format(dloss, gloss) +
                    '真数据判真比例:{:g} 假数据判真比例:{:g}/{:g}'.format(
                    dmean_real, dmean_fake, gmean_fake))
            if batch_idx % 100 == 0:
                fake = gnet(fixed_noises) # 由固定潜在张量生成假数据
                save_image(fake, # 保存假数据
                        './data/images_epoch{:02d}_batch{:03d}.png'.format(
                        epoch, batch_idx))

  • 相关阅读:
    Python学习第15天_模块
    Python学习第14天_文件读取写入
    Python学习第13天_练习(图书馆的创建)
    Python学习第12天_类
    Python学习第11天_参数
    Python学习第10天_函数
    Python学习第九天_模块的应用
    Android Bluetooth HIDL服务分析
    Mac下CLion配置Google GTest小结
    MacOS通过homebrew安装老版本的软件
  • 原文地址:https://www.cnblogs.com/devilmaycry812839668/p/9890287.html
Copyright © 2011-2022 走看看