zoukankan      html  css  js  c++  java
  • torch jupyter写的WGAN-GP

    训练使用的是天池的GPU,没有用到tensorBroadX,用的是静态的plt显示Loss。

    我通过实验证明了同样的网络使用WGAN-GP的网络架构对比没有使用WGAN-GP的网络架构会有更加不容易模型崩溃(model collapse)。

    下面是代码,使用的数据集CIFAR10。

    utils.py文件在https://www.cnblogs.com/abc23/p/14390153.html
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    import PIL.Image as Image
    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    import torchvision
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    import utils
    from torch.autograd import grad
    from torch.autograd import Variable
    
    import functools
    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision.utils as vutils
    %matplotlib inline
    # import torchlib
    """ gpu """
    gpu_id = [0]
    utils.cuda_devices(gpu_id)
    
    
    """ param """
    epochs = 500
    batch_size = 128
    lr = 0.0002
    n_critic = 5
    z_dim = 100
    # 决定我们在哪个设备上运行
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    def gradient_penalty(x, y, f):
        # interpolation
        shape = [x.size(0)] + [1] * (x.dim() - 1)
        alpha = utils.cuda(torch.rand(shape))
        z = x + alpha * (y - x)
    
        # gradient penalty
        z = utils.cuda(Variable(z, requires_grad=True))
        o = f(z)
        g = grad(o, z, grad_outputs=utils.cuda(torch.ones(o.size())), create_graph=True)[0].view(z.size(0), -1)
        gp = ((g.norm(p=2, dim=1) - 1)**2).mean()
    
        return gp
    """ data """
    crop_size = 108
    re_size = 64
    offset_height = (218 - crop_size) // 2
    offset_width = (178 - crop_size) // 2
    crop = lambda x: x[:, offset_height:offset_height + crop_size, offset_offset_width + crop_size]
    
    transform = transforms.Compose([transforms.Resize(re_size),
                                   transforms.CenterCrop(re_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])
    
    dataset = torchvision.datasets.CIFAR10(root='dataset/', train=True, transform=transform, download=True)
    
    # imagenet_data = dsets.ImageFolder('./data/img_align_celeba', transform=transform)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=True)
    # 展示一些训练图片
    real_batch = next(iter(data_loader))
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
    class LayerNorm(nn.Module):
    
        def __init__(self, num_features, eps=1e-5, affine=True):
            super(LayerNorm, self).__init__()
            self.num_features = num_features
            self.affine = affine
            self.eps = eps
    
            if self.affine:
                self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
                self.beta = nn.Parameter(torch.zeros(num_features))
    
        def forward(self, x):
            # This implementation is too slow!!!
    
            shape = [-1] + [1] * (x.dim() - 1)
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
            y = (x - mean) / (std + self.eps)
            if self.affine:
                shape = [1, -1] + [1] * (x.dim() - 2)
                y = self.gamma.view(*shape) * y + self.beta.view(*shape)
            return y
    class Generator(nn.Module):
    
        def __init__(self, in_dim, dim=64):
            super(Generator, self).__init__()
    
            def dconv_bn_relu(in_dim, out_dim):
                return nn.Sequential(
                    nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                                       padding=2, output_padding=1, bias=False),
                    nn.BatchNorm2d(out_dim),
                    nn.ReLU())
    
            self.l1 = nn.Sequential(
                nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
                nn.BatchNorm1d(dim * 8 * 4 * 4),
                nn.ReLU())
            self.l2_5 = nn.Sequential(
                dconv_bn_relu(dim * 8, dim * 4),
                dconv_bn_relu(dim * 4, dim * 2),
                dconv_bn_relu(dim * 2, dim),
                nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
                nn.Tanh())
    
        def forward(self, x):
            y = self.l1(x)
            y = y.view(y.size(0), -1, 4, 4)
            y = self.l2_5(y)
            return y
    class DiscriminatorWGANGP(nn.Module):
    
        def __init__(self, in_dim, dim=64):
            super(DiscriminatorWGANGP, self).__init__()
    
            def conv_ln_lrelu(in_dim, out_dim):
                return nn.Sequential(
                    nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                    # Since there is no effective implementation of LayerNorm,
                    # we use InstanceNorm2d instead of LayerNorm here.
                    nn.InstanceNorm2d(out_dim, affine=True),
                    nn.LeakyReLU(0.2))
    
            self.ls = nn.Sequential(
                nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2),
                conv_ln_lrelu(dim, dim * 2),
                conv_ln_lrelu(dim * 2, dim * 4),
                conv_ln_lrelu(dim * 4, dim * 8),
                nn.Conv2d(dim * 8, 1, 4))
    
        def forward(self, x):
            y = self.ls(x)
            y = y.view(-1)
            return y
    D = DiscriminatorWGANGP(3)
    G = Generator(z_dim)
    bce = nn.BCEWithLogitsLoss()
    utils.cuda([D, G, bce])
    d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    """ load checkpoint """
    ckpt_dir = './checkpoints2/celeba_wgan_gp'
    utils.mkdir(ckpt_dir)
    try:
        ckpt = utils.load_checkpoint(ckpt_dir)
        start_epoch = ckpt['epoch']
        D.load_state_dict(ckpt['D'])
        G.load_state_dict(ckpt['G'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
    except:
        print(' [*] No checkpoint!')
        start_epoch = 0
    img_list = []
    G_losses = []
    D_losses = []
    z_sample = Variable(torch.randn(100, z_dim))
    z_sample = utils.cuda(z_sample)
    for epoch in range(start_epoch, epochs):
        for i, (imgs, _) in enumerate(data_loader):
            
            # step
            step = epoch * len(data_loader) + i + 1
    
            # set train
            G.train()
    
            # leafs
            imgs = Variable(imgs)
            bs = imgs.size(0)
            z = Variable(torch.randn(bs, z_dim))
            imgs, z = utils.cuda([imgs, z])
    
            f_imgs = G(z)
    
            # train D
            r_logit = D(imgs)
            f_logit = D(f_imgs.detach())
    
            wd = r_logit.mean() - f_logit.mean()  # Wasserstein-1 Distance
            gp = gradient_penalty(imgs.data, f_imgs.data, D)
            d_loss = -wd + gp * 10.0
    
            D.zero_grad()
            d_loss.backward()
            d_optimizer.step()
    
            if step % n_critic == 0:
                # train G
                z = utils.cuda(Variable(torch.randn(bs, z_dim)))
                f_imgs = G(z)
                f_logit = D(f_imgs)
                g_loss = -f_logit.mean()
                
                D.zero_grad()
                G.zero_grad()
                g_loss.backward()
                g_optimizer.step()
                
                # 为以后画损失图,保存损失
                G_losses.append(g_loss.item())
                D_losses.append(d_loss.item())
            
            if (i + 1) % 150 == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (epoch, i + 1, len(data_loader)))
    
            if (i + 1) % 150 == 0:
                G.eval()
                f_imgs_sample = (G(z_sample).data + 1) / 2.0
    
                save_dir = './sample_images_while_training/celeba_wgan_gp'
                utils.mkdir(save_dir)
                torchvision.utils.save_image(f_imgs_sample, '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, epoch, i + 1, len(data_loader)), nrow=10)
            if (i % 150 == 0):
                with torch.no_grad():
                    fake = G(z).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        utils.save_checkpoint({'epoch': epoch + 1,
                               'D': D.state_dict(),
                               'G': G.state_dict(),
                               'd_optimizer': d_optimizer.state_dict(),
                               'g_optimizer': g_optimizer.state_dict()},
                              '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch + 1),
                              max_keep=2)
    plt.figure(figsize=(10,5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses,label="G")
    plt.plot(D_losses,label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    # 从数据加载器中获取一批真实图像
    real_batch = next(iter(data_loader))
    
    # 画出真实图像
    plt.figure(figsize=(15,15))
    plt.subplot(1,2,1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
    
    # 画出来自最后一次训练的假图像
    plt.subplot(1,2,2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1],(1,2,0)))
    plt.show()
  • 相关阅读:
    webpack2使用ch4-向根目录index.html文件传参并使用参数 使用线上资源 压缩html
    webpack2使用ch3-自动化生成.html和内部引入的js自动更改
    webpack2使用ch2-entry和output简要说明
    webpack2使用ch1-目录说明
    less使用ch1--简单使用
    less使用ch1--认识语法
    vue2购物车ch4-(筛选v-for 点击的那个设置样式 设为默认地址其他 联动 非循环的列表选中和非选中 删除当前选中的列表)
    gulp使用2-gulp-less及watch和错误提示
    gulp使用1-入门指南
    vue2购物车ch3-(过滤器使用 单件商品金额计算 全选全不选 总金额计算 删除商品功能)
  • 原文地址:https://www.cnblogs.com/abc23/p/14390100.html
Copyright © 2011-2022 走看看