  • [pytorch] DCGAN Hands-on Tutorial (official tutorial)


    1. Introduction

    This tutorial walks through DCGANs with a concrete example. We will train a generative adversarial network (GAN) to produce new celebrities after showing it many photos of real celebrities. Most of the code here comes from the dcgan implementation in pytorch/examples; this article explains the implementation in detail and clarifies how and why the model works. No prior knowledge of GANs is required, but newcomers may need to spend some time understanding what is actually happening under the hood. Also, having one or two GPUs will help save training time. Let's get started.

    2. Overview

    2.1. What is a GAN (Generative Adversarial Network)?

    GANs are a framework for teaching a deep learning model to capture the distribution of the training data, so that we can generate new data from that same distribution. GANs were proposed by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets.

    They are made of two distinct models: a generator and a discriminator. The job of the generator is to produce fake images that look like the training images. The job of the discriminator is to decide whether a given image is a real training image or a fake image from the generator. During training, the generator constantly tries to outsmart the discriminator by producing better and better fakes, while the discriminator works to become a better detective and correctly classify the real and fake images.

    The equilibrium of this game is reached when the generator produces fakes that look as if they came directly from the training data, and the discriminator is left to always guess at 50% confidence whether the generator's output is real or fake.

    Now, let's define some notation to be used throughout the tutorial, starting with the discriminator. Let x be image data, and let D(x) be the discriminator network's output: the probability that x came from the training data rather than the generator. Since we are dealing with images, the input to D(x) is an image of CHW size 3x64x64. Intuitively, D(x) should be high when x comes from the training data and low when x comes from the generator. D(x) can also be thought of as a traditional binary classifier.

    For the generator's notation, let z be a latent space vector sampled from a standard normal distribution. G(z) represents the generator function that maps the latent vector z to data space. The goal of G is to estimate the distribution that the training data come from (p_data) so that it can generate fake samples from that estimated distribution (p_g).

    So, D(G(z)) is the probability (a scalar) that the output of the generator G is a real image. As described in Goodfellow's paper, D and G play a minimax game in which D tries to maximize the probability that it correctly classifies reals and fakes (log D(x)), and G tries to minimize the probability that D predicts its outputs are fake (log(1 - D(G(z)))). From the paper, the GAN loss function is:

    \min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log(1 - D(G(z)))\big]

    In theory, the solution to this minimax game is p_g = p_data, and the discriminator guesses randomly whether inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in reality models do not always train to this point.
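    As a quick illustration of that equilibrium (ours, not part of the original tutorial): plugging D(x) = D(G(z)) = 0.5 into the value function shows the game settles at V = -log 4, since both expectations reduce to a single log term.

```python
import math

# At equilibrium p_g = p_data, the optimal discriminator outputs 0.5 everywhere,
# so both expectations in V(D, G) collapse to single log terms:
d_real = 0.5   # D(x) on real samples
d_fake = 0.5   # D(G(z)) on fake samples

value = math.log(d_real) + math.log(1 - d_fake)
print(value)  # ≈ -1.386, i.e. -log(4)
```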

    2.2. What is a DCGAN (Deep Convolutional Generative Adversarial Network)?

    A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional layers in the discriminator and transposed-convolutional layers in the generator. It was first proposed by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks.

    The discriminator is made up of strided convolution layers, batch normalization layers, and LeakyReLU activations. The input is a 3x64x64 image, and the output is the scalar probability that the image came from the real data distribution.

    The generator is made up of transposed-convolution layers, batch normalization layers, and ReLU activations. The input is a latent vector z drawn from a standard normal distribution, and the output is a 3x64x64 RGB image. The transposed-convolution layers transform the latent vector into a volume with the same dimensions as a real image. In the paper, the authors also give tips on how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.

    from __future__ import print_function
    #%matplotlib inline
    import argparse
    import os
    import random
    import torch
    import torch.nn as nn
    import torch.nn.parallel
    import torch.backends.cudnn as cudnn
    import torch.optim as optim
    import torch.utils.data
    import torchvision.datasets as dset
    import torchvision.transforms as transforms
    import torchvision.utils as vutils
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    from IPython.display import HTML
    
    # Set random seed for reproducibility
    manualSeed = 999
    #manualSeed = random.randint(1, 10000) # use if you want new results
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    

    Output:

    Random Seed:  999
    

    3. Inputs

    • dataroot: the path to the root of the dataset folder
    • workers: the number of worker threads for loading the data with the DataLoader
    • batch_size: the batch size used in training. The DCGAN paper uses 128
    • image_size: the spatial size of the training images. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed; see here for details
    • nc: the number of channels in the input images. For color images this is 3
    • nz: the length of the latent vector
    • ngf: relates to the depth of the feature maps carried through the generator
    • ndf: sets the depth of the feature maps propagated through the discriminator
    • num_epochs: the number of training epochs to run. Training for more epochs will probably lead to better results but will also take much longer
    • lr: the learning rate. The DCGAN paper uses 0.0002
    • beta1: the beta1 hyperparameter for the Adam optimizers. As in the paper, this is 0.5
    • ngpu: the number of GPUs available. If this is 0, the code will run in CPU mode; if greater than 0, it will run on that number of GPUs
    # Root directory for dataset
    dataroot = "data/celeba"
    
    # Number of workers for dataloader
    workers = 2
    
    # Batch size during training
    batch_size = 128
    
    # Spatial size of training images. All images will be resized to this
    #   size using a transformer.
    image_size = 64
    
    # Number of channels in the training images. For color images this is 3
    nc = 3
    
    # Size of z latent vector (i.e. size of generator input)
    nz = 100
    
    # Size of feature maps in generator
    ngf = 64
    
    # Size of feature maps in discriminator
    ndf = 64
    
    # Number of training epochs
    num_epochs = 5
    
    # Learning rate for optimizers
    lr = 0.0002
    
    # Beta1 hyperparam for Adam optimizers
    beta1 = 0.5
    
    # Number of GPUs available. Use 0 for CPU mode.
    ngpu = 1
    

    4. Data

    In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded from the linked site or from Google Drive. The dataset downloads as a zip file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into it. Then set dataroot to the directory you just created. The resulting directory structure should be:

    /path/to/celeba
        -> img_align_celeba
            -> 188242.jpg
            -> 173822.jpg
            -> 284702.jpg
            -> 537394.jpg
               ...
    

    This is an important step because we will use the ImageFolder dataset class, which requires there to be subdirectories in the dataset's root folder. Now we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.

    # We can use an image folder dataset the way we have it setup.
    # Create the dataset
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=workers)
    
    # Decide which device we want to run on
    device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
    
    # Plot some training images
    real_batch = next(iter(dataloader))
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
    

    5. Implementation

    With our input parameters set and the dataset prepared, we can get into the implementation. We will start with the weight initialization strategy, then discuss the generator, the discriminator, the loss functions, and the training loop in detail.

    5.1. Weight initialization

    In the DCGAN paper, the authors specify that all model weights should be randomly initialized from a normal distribution with mean 0 and standard deviation 0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, transposed-convolutional, and batch-normalization layers to meet this criterion. This function is applied to a model immediately after initialization.

    # custom weights initialization called on netG and netD
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
    

    5.2. Generator

    The generator G is designed to map the latent space vector z to data space. Since our data are images, converting z to data space means ultimately creating an RGB image with the same size as the training images (i.e. 3x64x64).

    In practice, this is accomplished through a series of strided two-dimensional transposed-convolution layers, each paired with a batch normalization layer and a ReLU activation. The output of the generator is fed through a tanh function to return it to the input data range of [-1, 1].

    It is worth noting the existence of the batch normalization layers after the transposed-convolution layers, as this is a critical contribution of the DCGAN paper. These layers help with the flow of gradients during training; an illustration of the generator from the DCGAN paper is shown below.

    Notice how the inputs we set in the input section (nz, ngf, and nc) influence the generator architecture in the code. nz is the length of the input vector z, ngf relates to the size of the feature maps propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). Below is the code for the generator.
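    The spatial sizes noted in the comments of the generator code below follow the standard transposed-convolution output formula, out = (in - 1)·stride - 2·padding + kernel (with no output_padding or dilation). A quick pure-Python sketch (the helper name deconv_out is ours, not part of the tutorial) traces z from 1x1 up to 64x64:

```python
# Output spatial size of a transposed convolution (no output_padding, no dilation):
#   out = (in - 1) * stride - 2 * padding + kernel
def deconv_out(size, kernel, stride, padding):
    return (size - 1) * stride - 2 * padding + kernel

# Trace the generator's five ConvTranspose2d layers, starting from the 1x1 latent z
size = 1
for kernel, stride, padding in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    size = deconv_out(size, kernel, stride, padding)
    print(size)  # prints 4, 8, 16, 32, 64
```

    This is why each 4/2/1 (kernel/stride/padding) layer doubles the spatial size, while the first 4/1/0 layer expands the 1x1 latent vector to 4x4.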

    # Generator Code
    
    class Generator(nn.Module):
        def __init__(self, ngpu):
            super(Generator, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(
                # input is Z, going into a convolution
                nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                # state size. (ngf*8) x 4 x 4
                nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 4),
                nn.ReLU(True),
                # state size. (ngf*4) x 8 x 8
                nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 2),
                nn.ReLU(True),
                # state size. (ngf*2) x 16 x 16
                nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf),
                nn.ReLU(True),
                # state size. (ngf) x 32 x 32
                nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
                nn.Tanh()
                # state size. (nc) x 64 x 64
            )
    
        def forward(self, input):
            return self.main(input)
    

    Now we can instantiate the generator and apply the weights_init function. Check out the printed model to see the structure of the generator object.

    # Create the generator
    netG = Generator(ngpu).to(device)
    
    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))
    
    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.02.
    netG.apply(weights_init)
    
    # Print the model
    print(netG)
    

    Output:

    Generator(
      (main): Sequential(
        (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
        (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (8): ReLU(inplace=True)
        (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (11): ReLU(inplace=True)
        (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (13): Tanh()
      )
    )
    

    5.3. Discriminator

    As mentioned, the discriminator D is a binary classification network that takes an image as input and outputs the scalar probability that the input image is real (as opposed to fake).

    Here, D takes a 3x64x64 image as input, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if necessary for the problem, but the use of strided convolution, BatchNorm, and LeakyReLU is significant. The DCGAN paper mentions that using strided convolution rather than pooling to downsample is a good practice because it lets the network learn its own pooling function. The batch norm and leaky relu functions also promote healthy gradient flow, which is critical for the learning process of both G and D.
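    The strided downsampling described above follows the usual convolution output formula, out = floor((in + 2·padding - kernel) / stride) + 1. A small pure-Python sketch (the helper name conv_out is ours) shows how each 4/2/1 layer halves the spatial size on the way from 64x64 down to the final 1x1 scalar:

```python
# Output spatial size of a convolution (no dilation):
#   out = (in + 2 * padding - kernel) // stride + 1
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

# The discriminator's four strided 4x4 conv layers halve the image each time,
# then a final 4x4 conv with stride 1 and no padding reduces 4x4 to 1x1
size = 64
for _ in range(4):
    size = conv_out(size, 4, 2, 1)
    print(size)  # prints 32, 16, 8, 4
print(conv_out(size, 4, 1, 0))  # prints 1
```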

    Discriminator code:

    class Discriminator(nn.Module):
        def __init__(self, ngpu):
            super(Discriminator, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(
                # input is (nc) x 64 x 64
                nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf) x 32 x 32
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*2) x 16 x 16
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x 8 x 8
                nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*8) x 4 x 4
                nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()
            )
    
        def forward(self, input):
            return self.main(input)
    

    Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model's structure.

    # Create the Discriminator
    netD = Discriminator(ngpu).to(device)
    
    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))
    
    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.02.
    netD.apply(weights_init)
    
    # Print the model
    print(netD)
    

    Output:

    Discriminator(
      (main): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
        (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): LeakyReLU(negative_slope=0.2, inplace=True)
        (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (7): LeakyReLU(negative_slope=0.2, inplace=True)
        (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (10): LeakyReLU(negative_slope=0.2, inplace=True)
        (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
        (12): Sigmoid()
      )
    )
    

    5.4. Loss functions and optimizers

    With D and G set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss) function defined in PyTorch:

    \ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \quad l_n = -\left[y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n)\right]

    Notice how this function provides the calculation of both log components in the objective function (i.e. log(D(x)) and log(1 - D(G(z)))). We can specify which part of the BCE equation to use with the y input. This is accomplished in the training loop that is coming up soon, but it is important to understand how we can choose which component we wish to calculate just by changing y (i.e. the ground-truth labels).
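    The label-switching behavior described above can be illustrated with plain math.log rather than PyTorch (a sketch, not the BCELoss implementation itself): y = 1 reduces BCE to -log(x), and y = 0 reduces it to -log(1 - x).

```python
import math

def bce(x, y):
    """Binary cross entropy for a single prediction x and label y."""
    return -(y * math.log(x) + (1 - y) * math.log(1 - x))

d_out = 0.9  # the discriminator's output for some image

# y = 1 selects the -log(x) term (used for real images, and for G's loss)
print(bce(d_out, 1), -math.log(d_out))      # identical values

# y = 0 selects the -log(1 - x) term (used for fake images in D's loss)
print(bce(d_out, 0), -math.log(1 - d_out))  # identical values
```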

    Next, we define our real label as 1 and the fake label as 0. These labels will be used when calculating the losses of D and G, and this is also the convention used in the original GAN paper.

    Finally, we set up two separate optimizers, one for D and one for G. As specified in the DCGAN paper, both are Adam optimizers with lr = 0.0002 and beta1 = 0.5. For keeping track of the generator's learning progression, we will generate a fixed batch of latent vectors drawn from a Gaussian distribution (i.e. fixed_noise). In the training loop, we will periodically input this fixed_noise into G, and over the iterations we will see images form out of the noise.

    # Initialize BCELoss function
    criterion = nn.BCELoss()
    
    # Create batch of latent vectors that we will use to visualize
    #  the progression of the generator
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    
    # Establish convention for real and fake labels during training
    real_label = 1.
    fake_label = 0.
    
    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
    

    5.5. Training

    Finally, now that we have all parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong.

    Here we will closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we will construct different mini-batches for real and fake images, and also adjust G's objective function to maximize log(D(G(z))). Training is split into two main parts: Part 1 updates the discriminator and Part 2 updates the generator.

    5.5.1. Part 1 - Train the discriminator

    Recall that the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In terms of Goodfellow's paper, we wish to "update the discriminator by ascending its stochastic gradient".

    Practically, we want to maximize log(D(x)) + log(1 - D(G(z))). Due to the separate mini-batch suggestion from ganhacks, we will calculate this in two steps. First, we construct a batch of real samples from the training set, forward pass them through D, calculate the loss (log(D(x))), then calculate the gradients in a backward pass. Second, we construct a batch of fake samples with the current generator, forward pass this batch through D, calculate the loss (log(1 - D(G(z)))), and accumulate the gradients with another backward pass. Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the discriminator's optimizer.

    5.5.2. Part 2 - Train the generator

    As stated in the original paper, we want to train the generator by minimizing log(1 - D(G(z))) in an effort to generate better fakes. But as mentioned, Goodfellow showed that this does not provide sufficient gradients, especially early in the learning process. As a fix, we instead wish to maximize log(D(G(z))).
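    The "insufficient gradients" problem can be seen numerically. A small pure-Python sketch (the variable names are ours) compares the analytic gradient magnitudes of the two objectives when the discriminator confidently rejects a fake, i.e. when D(G(z)) is near 0:

```python
# d is D(G(z)), the discriminator's output on a fake image
d = 0.001  # early in training, D easily rejects fakes

# Analytic gradients with respect to d:
#   d/dd [log(1 - d)] = -1 / (1 - d)  -> tiny magnitude near d = 0 (saturates)
#   d/dd [log(d)]     =  1 / d        -> huge magnitude near d = 0
grad_minimize = -1 / (1 - d)  # gradient of the original minimization objective
grad_maximize = 1 / d         # gradient of the "maximize log(D(G(z)))" trick

print(abs(grad_minimize))  # ≈ 1.001: barely any signal for G
print(abs(grad_maximize))  # ≈ 1000:  strong signal for G
```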

    In the code, we accomplish this by: classifying the generator output from Part 1 with the discriminator, computing G's loss using real labels as the ground truth, computing G's gradients in a backward pass, and finally updating G's parameters with an optimizer step. It may seem counter-intuitive to use the real labels as the ground-truth labels for the loss function, but this allows us to use the log(x) part of BCELoss (rather than the log(1 - x) part), which is exactly what we want.

    Finally, we will do some statistical reporting, and at the end of each epoch we will push our fixed_noise batch through the generator to visually track the progress of G's training. The training statistics reported are:

    • Loss_D - discriminator loss, calculated as the sum of losses for the all-real and all-fake batches (log(D(x)) + log(1 - D(G(z)))).
    • Loss_G - generator loss, calculated as log(D(G(z))).
    • D(x) - the average output (across the batch) of the discriminator for the all-real batch. This should start close to 1, then theoretically converge to 0.5 as G gets better. Think about why this is.
    • D(G(z)) - the average discriminator output for the all-fake batch. The first number is before D is updated, and the second is after D is updated. These numbers should start near 0 and converge to 0.5 as G gets better. Think about why this is.

    Note: this step might take a while, depending on how many epochs you run and whether you removed some data from the dataset.

    # Training Loop
    
    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
    
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()
    
            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()
    
            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()
    
            # Output training stats
            if i % 50 == 0:
                print('[%d/%d][%d/%d]	Loss_D: %.4f	Loss_G: %.4f	D(x): %.4f	D(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
    
            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
    
            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
    
            iters += 1
    

    Output:

    Starting Training Loop...
    [0/5][0/1583]   Loss_D: 1.9847  Loss_G: 5.5914  D(x): 0.6004    D(G(z)): 0.6680 / 0.0062
    [0/5][50/1583]  Loss_D: 0.4017  Loss_G: 17.8778 D(x): 0.8368    D(G(z)): 0.0000 / 0.0000
    [0/5][100/1583] Loss_D: 2.8508  Loss_G: 22.8236 D(x): 0.9634    D(G(z)): 0.8460 / 0.0000
    [0/5][150/1583] Loss_D: 0.2360  Loss_G: 5.4596  D(x): 0.8440    D(G(z)): 0.0308 / 0.0090
    [0/5][200/1583] Loss_D: 1.6425  Loss_G: 4.7064  D(x): 0.3414    D(G(z)): 0.0079 / 0.0176
    [0/5][250/1583] Loss_D: 0.2731  Loss_G: 4.4791  D(x): 0.9431    D(G(z)): 0.1680 / 0.0225
    [0/5][300/1583] Loss_D: 0.6051  Loss_G: 4.6251  D(x): 0.8278    D(G(z)): 0.2424 / 0.0230
    [0/5][350/1583] Loss_D: 0.7070  Loss_G: 1.6842  D(x): 0.6204    D(G(z)): 0.0824 / 0.2560
    [0/5][400/1583] Loss_D: 0.6758  Loss_G: 4.0679  D(x): 0.9354    D(G(z)): 0.3946 / 0.0288
    [0/5][450/1583] Loss_D: 0.5348  Loss_G: 5.7453  D(x): 0.9625    D(G(z)): 0.3514 / 0.0083
    [0/5][500/1583] Loss_D: 0.6896  Loss_G: 7.8784  D(x): 0.9364    D(G(z)): 0.4080 / 0.0012
    [0/5][550/1583] Loss_D: 0.4377  Loss_G: 8.1336  D(x): 0.9425    D(G(z)): 0.2840 / 0.0007
    [0/5][600/1583] Loss_D: 1.8797  Loss_G: 2.5577  D(x): 0.3201    D(G(z)): 0.0123 / 0.1258
    [0/5][650/1583] Loss_D: 1.3832  Loss_G: 10.6947 D(x): 0.9770    D(G(z)): 0.7006 / 0.0001
    [0/5][700/1583] Loss_D: 0.3195  Loss_G: 3.7833  D(x): 0.8474    D(G(z)): 0.0844 / 0.0789
    [0/5][750/1583] Loss_D: 0.2142  Loss_G: 4.1755  D(x): 0.8942    D(G(z)): 0.0813 / 0.0232
    [0/5][800/1583] Loss_D: 1.4535  Loss_G: 2.3077  D(x): 0.4024    D(G(z)): 0.0111 / 0.1806
    [0/5][850/1583] Loss_D: 0.4109  Loss_G: 6.3312  D(x): 0.9002    D(G(z)): 0.2153 / 0.0048
    [0/5][900/1583] Loss_D: 2.7930  Loss_G: 4.5548  D(x): 0.1428    D(G(z)): 0.0022 / 0.0240
    [0/5][950/1583] Loss_D: 0.3493  Loss_G: 5.5976  D(x): 0.8767    D(G(z)): 0.1498 / 0.0080
    [0/5][1000/1583]        Loss_D: 0.6749  Loss_G: 5.0457  D(x): 0.6349    D(G(z)): 0.0215 / 0.0194
    [0/5][1050/1583]        Loss_D: 0.4009  Loss_G: 4.5791  D(x): 0.7669    D(G(z)): 0.0484 / 0.0260
    [0/5][1100/1583]        Loss_D: 0.3453  Loss_G: 2.7277  D(x): 0.8885    D(G(z)): 0.1408 / 0.1219
    [0/5][1150/1583]        Loss_D: 0.2484  Loss_G: 5.0396  D(x): 0.8727    D(G(z)): 0.0595 / 0.0174
    [0/5][1200/1583]        Loss_D: 0.6760  Loss_G: 3.2315  D(x): 0.7052    D(G(z)): 0.1756 / 0.0688
    [0/5][1250/1583]        Loss_D: 0.5845  Loss_G: 3.1392  D(x): 0.7576    D(G(z)): 0.2018 / 0.0673
    [0/5][1300/1583]        Loss_D: 0.2762  Loss_G: 4.9311  D(x): 0.8666    D(G(z)): 0.0933 / 0.0136
    [0/5][1350/1583]        Loss_D: 0.4753  Loss_G: 4.7346  D(x): 0.8595    D(G(z)): 0.2228 / 0.0170
    [0/5][1400/1583]        Loss_D: 0.3764  Loss_G: 5.9964  D(x): 0.7758    D(G(z)): 0.0109 / 0.0098
    [0/5][1450/1583]        Loss_D: 0.4025  Loss_G: 3.8804  D(x): 0.8158    D(G(z)): 0.1413 / 0.0320
    [0/5][1500/1583]        Loss_D: 0.6678  Loss_G: 2.7302  D(x): 0.6980    D(G(z)): 0.1486 / 0.1040
    [0/5][1550/1583]        Loss_D: 0.6062  Loss_G: 3.1664  D(x): 0.7235    D(G(z)): 0.1305 / 0.0783
    [1/5][0/1583]   Loss_D: 0.6615  Loss_G: 8.0512  D(x): 0.9412    D(G(z)): 0.3797 / 0.0007
    [1/5][50/1583]  Loss_D: 0.8057  Loss_G: 2.1089  D(x): 0.5929    D(G(z)): 0.0869 / 0.1893
    [1/5][100/1583] Loss_D: 0.4206  Loss_G: 3.3245  D(x): 0.7409    D(G(z)): 0.0554 / 0.0640
    [1/5][150/1583] Loss_D: 0.6361  Loss_G: 4.0774  D(x): 0.7830    D(G(z)): 0.2605 / 0.0256
    [1/5][200/1583] Loss_D: 1.7394  Loss_G: 7.5861  D(x): 0.9685    D(G(z)): 0.7499 / 0.0014
    [1/5][250/1583] Loss_D: 0.4597  Loss_G: 3.1064  D(x): 0.7053    D(G(z)): 0.0265 / 0.0844
    [1/5][300/1583] Loss_D: 0.4190  Loss_G: 2.2869  D(x): 0.7942    D(G(z)): 0.1163 / 0.1660
    [1/5][350/1583] Loss_D: 0.4724  Loss_G: 4.3673  D(x): 0.8292    D(G(z)): 0.2106 / 0.0213
    [1/5][400/1583] Loss_D: 0.2877  Loss_G: 4.3217  D(x): 0.8823    D(G(z)): 0.1125 / 0.0225
    [1/5][450/1583] Loss_D: 0.8508  Loss_G: 0.8635  D(x): 0.5397    D(G(z)): 0.0390 / 0.5324
    [1/5][500/1583] Loss_D: 0.4317  Loss_G: 3.1585  D(x): 0.7646    D(G(z)): 0.0931 / 0.0767
    [1/5][550/1583] Loss_D: 0.8256  Loss_G: 6.1484  D(x): 0.9395    D(G(z)): 0.4563 / 0.0051
    [1/5][600/1583] Loss_D: 0.9765  Loss_G: 1.5017  D(x): 0.4807    D(G(z)): 0.0076 / 0.2843
    [1/5][650/1583] Loss_D: 1.8020  Loss_G: 8.8270  D(x): 0.9480    D(G(z)): 0.7248 / 0.0003
    [1/5][700/1583] Loss_D: 0.3680  Loss_G: 3.7401  D(x): 0.7991    D(G(z)): 0.0949 / 0.0404
    [1/5][750/1583] Loss_D: 0.5763  Loss_G: 2.0559  D(x): 0.6739    D(G(z)): 0.0851 / 0.1882
    [1/5][800/1583] Loss_D: 0.7773  Loss_G: 5.0999  D(x): 0.9399    D(G(z)): 0.4335 / 0.0142
    [1/5][850/1583] Loss_D: 0.3901  Loss_G: 3.4356  D(x): 0.8537    D(G(z)): 0.1744 / 0.0491
    [1/5][900/1583] Loss_D: 0.7268  Loss_G: 6.5356  D(x): 0.9635    D(G(z)): 0.4428 / 0.0027
    [1/5][950/1583] Loss_D: 0.4570  Loss_G: 3.8893  D(x): 0.8707    D(G(z)): 0.2376 / 0.0304
    [1/5][1000/1583]        Loss_D: 1.3551  Loss_G: 7.2447  D(x): 0.9333    D(G(z)): 0.6422 / 0.0030
    [1/5][1050/1583]        Loss_D: 0.3905  Loss_G: 3.3360  D(x): 0.8183    D(G(z)): 0.1462 / 0.0537
    [1/5][1100/1583]        Loss_D: 1.3858  Loss_G: 0.9796  D(x): 0.3336    D(G(z)): 0.0259 / 0.4584
    [1/5][1150/1583]        Loss_D: 0.5776  Loss_G: 2.6197  D(x): 0.6443    D(G(z)): 0.0532 / 0.1051
    [1/5][1200/1583]        Loss_D: 0.5647  Loss_G: 3.5713  D(x): 0.8026    D(G(z)): 0.2450 / 0.0428
    [1/5][1250/1583]        Loss_D: 0.4568  Loss_G: 3.6666  D(x): 0.8934    D(G(z)): 0.2581 / 0.0403
    [1/5][1300/1583]        Loss_D: 0.7197  Loss_G: 1.8175  D(x): 0.6211    D(G(z)): 0.1035 / 0.2184
    [1/5][1350/1583]        Loss_D: 0.5255  Loss_G: 3.2736  D(x): 0.8141    D(G(z)): 0.2233 / 0.0574
    [1/5][1400/1583]        Loss_D: 0.8241  Loss_G: 3.0776  D(x): 0.7807    D(G(z)): 0.3659 / 0.0743
    [1/5][1450/1583]        Loss_D: 0.4302  Loss_G: 3.3777  D(x): 0.9058    D(G(z)): 0.2518 / 0.0519
    [1/5][1500/1583]        Loss_D: 0.4173  Loss_G: 2.5610  D(x): 0.7916    D(G(z)): 0.1358 / 0.1058
    [1/5][1550/1583]        Loss_D: 0.7993  Loss_G: 5.1228  D(x): 0.8527    D(G(z)): 0.4162 / 0.0104
    [2/5][0/1583]   Loss_D: 0.4844  Loss_G: 2.2263  D(x): 0.7645    D(G(z)): 0.1510 / 0.1426
    [2/5][50/1583]  Loss_D: 0.6756  Loss_G: 2.4608  D(x): 0.5915    D(G(z)): 0.0657 / 0.1248
    [2/5][100/1583] Loss_D: 0.4391  Loss_G: 3.0181  D(x): 0.7901    D(G(z)): 0.1486 / 0.0744
    [2/5][150/1583] Loss_D: 0.5683  Loss_G: 1.8918  D(x): 0.7083    D(G(z)): 0.1411 / 0.1858
    [2/5][200/1583] Loss_D: 0.5932  Loss_G: 3.3342  D(x): 0.9111    D(G(z)): 0.3576 / 0.0522
    [2/5][250/1583] Loss_D: 0.7331  Loss_G: 2.3817  D(x): 0.6635    D(G(z)): 0.1665 / 0.1397
    [2/5][300/1583] Loss_D: 0.5493  Loss_G: 2.3824  D(x): 0.7491    D(G(z)): 0.1742 / 0.1196
    [2/5][350/1583] Loss_D: 0.6197  Loss_G: 1.8560  D(x): 0.6443    D(G(z)): 0.1018 / 0.1972
    [2/5][400/1583] Loss_D: 0.6172  Loss_G: 3.0777  D(x): 0.8482    D(G(z)): 0.3251 / 0.0621
    [2/5][450/1583] Loss_D: 0.5047  Loss_G: 3.2941  D(x): 0.9174    D(G(z)): 0.3116 / 0.0566
    [2/5][500/1583] Loss_D: 0.7335  Loss_G: 1.2796  D(x): 0.5676    D(G(z)): 0.0575 / 0.3470
    [2/5][550/1583] Loss_D: 0.7716  Loss_G: 1.9450  D(x): 0.5513    D(G(z)): 0.0580 / 0.1922
    [2/5][600/1583] Loss_D: 0.4425  Loss_G: 2.0531  D(x): 0.8015    D(G(z)): 0.1640 / 0.1686
    [2/5][650/1583] Loss_D: 1.0964  Loss_G: 4.4602  D(x): 0.9096    D(G(z)): 0.5833 / 0.0163
    [2/5][700/1583] Loss_D: 0.4745  Loss_G: 2.8636  D(x): 0.8492    D(G(z)): 0.2403 / 0.0770
    [2/5][750/1583] Loss_D: 0.4947  Loss_G: 3.6931  D(x): 0.8803    D(G(z)): 0.2732 / 0.0364
    [2/5][800/1583] Loss_D: 0.9355  Loss_G: 4.3906  D(x): 0.9120    D(G(z)): 0.5168 / 0.0195
    [2/5][850/1583] Loss_D: 0.9213  Loss_G: 1.6006  D(x): 0.4645    D(G(z)): 0.0339 / 0.2467
    [2/5][900/1583] Loss_D: 0.5337  Loss_G: 3.7601  D(x): 0.9101    D(G(z)): 0.3310 / 0.0314
    [2/5][950/1583] Loss_D: 1.2562  Loss_G: 4.9530  D(x): 0.9432    D(G(z)): 0.6244 / 0.0144
    [2/5][1000/1583]        Loss_D: 0.4187  Loss_G: 2.4701  D(x): 0.8454    D(G(z)): 0.1945 / 0.1129
    [2/5][1050/1583]        Loss_D: 0.5796  Loss_G: 2.3732  D(x): 0.7714    D(G(z)): 0.2253 / 0.1216
    [2/5][1100/1583]        Loss_D: 0.6325  Loss_G: 2.5824  D(x): 0.8307    D(G(z)): 0.3235 / 0.0939
    [2/5][1150/1583]        Loss_D: 0.7639  Loss_G: 3.9487  D(x): 0.9031    D(G(z)): 0.4398 / 0.0291
    [2/5][1200/1583]        Loss_D: 0.7040  Loss_G: 3.3561  D(x): 0.8073    D(G(z)): 0.3403 / 0.0500
    [2/5][1250/1583]        Loss_D: 1.0567  Loss_G: 4.7122  D(x): 0.9292    D(G(z)): 0.5656 / 0.0155
    [2/5][1300/1583]        Loss_D: 0.5431  Loss_G: 2.4260  D(x): 0.7628    D(G(z)): 0.2028 / 0.1116
    [2/5][1350/1583]        Loss_D: 0.7633  Loss_G: 4.1670  D(x): 0.9257    D(G(z)): 0.4404 / 0.0237
    [2/5][1400/1583]        Loss_D: 2.1958  Loss_G: 0.5288  D(x): 0.1539    D(G(z)): 0.0147 / 0.6404
    [2/5][1450/1583]        Loss_D: 0.6991  Loss_G: 1.8573  D(x): 0.5818    D(G(z)): 0.0621 / 0.1980
    [2/5][1500/1583]        Loss_D: 0.8286  Loss_G: 3.6899  D(x): 0.8805    D(G(z)): 0.4440 / 0.0364
    [2/5][1550/1583]        Loss_D: 0.5100  Loss_G: 2.5931  D(x): 0.7721    D(G(z)): 0.1862 / 0.0989
    [3/5][0/1583]   Loss_D: 0.7136  Loss_G: 2.6315  D(x): 0.8178    D(G(z)): 0.3462 / 0.1034
    [3/5][50/1583]  Loss_D: 0.6472  Loss_G: 2.6359  D(x): 0.7572    D(G(z)): 0.2460 / 0.0962
    [3/5][100/1583] Loss_D: 0.5211  Loss_G: 1.7793  D(x): 0.7275    D(G(z)): 0.1402 / 0.2050
    [3/5][150/1583] Loss_D: 0.9620  Loss_G: 4.0717  D(x): 0.9423    D(G(z)): 0.5500 / 0.0243
    [3/5][200/1583] Loss_D: 0.5469  Loss_G: 2.1994  D(x): 0.7581    D(G(z)): 0.1972 / 0.1359
    [3/5][250/1583] Loss_D: 0.3941  Loss_G: 2.7071  D(x): 0.7281    D(G(z)): 0.0401 / 0.0902
    [3/5][300/1583] Loss_D: 0.6482  Loss_G: 1.4858  D(x): 0.6275    D(G(z)): 0.1085 / 0.2802
    [3/5][350/1583] Loss_D: 1.2781  Loss_G: 4.7393  D(x): 0.9594    D(G(z)): 0.6587 / 0.0120
    [3/5][400/1583] Loss_D: 0.5942  Loss_G: 2.8406  D(x): 0.7861    D(G(z)): 0.2579 / 0.0784
    [3/5][450/1583] Loss_D: 0.5395  Loss_G: 1.9849  D(x): 0.6755    D(G(z)): 0.0854 / 0.1764
    [3/5][500/1583] Loss_D: 0.7941  Loss_G: 2.5871  D(x): 0.7891    D(G(z)): 0.3784 / 0.1006
    [3/5][550/1583] Loss_D: 0.6556  Loss_G: 3.9228  D(x): 0.9328    D(G(z)): 0.4053 / 0.0254
    [3/5][600/1583] Loss_D: 0.6489  Loss_G: 3.2773  D(x): 0.8385    D(G(z)): 0.3419 / 0.0490
    [3/5][650/1583] Loss_D: 0.9217  Loss_G: 1.3858  D(x): 0.4992    D(G(z)): 0.0854 / 0.3095
    [3/5][700/1583] Loss_D: 0.4947  Loss_G: 2.2791  D(x): 0.7948    D(G(z)): 0.2035 / 0.1332
    [3/5][750/1583] Loss_D: 0.9676  Loss_G: 1.6087  D(x): 0.4641    D(G(z)): 0.0363 / 0.2599
    [3/5][800/1583] Loss_D: 0.5918  Loss_G: 1.8852  D(x): 0.7019    D(G(z)): 0.1637 / 0.1948
    [3/5][850/1583] Loss_D: 0.7856  Loss_G: 3.4243  D(x): 0.8672    D(G(z)): 0.4219 / 0.0512
    [3/5][900/1583] Loss_D: 0.5023  Loss_G: 2.7348  D(x): 0.8372    D(G(z)): 0.2416 / 0.0851
    [3/5][950/1583] Loss_D: 0.9028  Loss_G: 1.8348  D(x): 0.5362    D(G(z)): 0.1219 / 0.2110
    [3/5][1000/1583]        Loss_D: 0.8118  Loss_G: 3.9327  D(x): 0.9092    D(G(z)): 0.4586 / 0.0306
    [3/5][1050/1583]        Loss_D: 0.8709  Loss_G: 3.1103  D(x): 0.8752    D(G(z)): 0.4686 / 0.0639
    [3/5][1100/1583]        Loss_D: 0.4286  Loss_G: 2.9141  D(x): 0.8379    D(G(z)): 0.1912 / 0.0741
    [3/5][1150/1583]        Loss_D: 0.6005  Loss_G: 1.8091  D(x): 0.7044    D(G(z)): 0.1727 / 0.2042
    [3/5][1200/1583]        Loss_D: 0.7432  Loss_G: 3.8108  D(x): 0.9088    D(G(z)): 0.4344 / 0.0297
    [3/5][1250/1583]        Loss_D: 0.6872  Loss_G: 1.8717  D(x): 0.7355    D(G(z)): 0.2731 / 0.1789
    [3/5][1300/1583]        Loss_D: 0.5740  Loss_G: 3.4426  D(x): 0.8874    D(G(z)): 0.3380 / 0.0422
    [3/5][1350/1583]        Loss_D: 0.5689  Loss_G: 2.0738  D(x): 0.6823    D(G(z)): 0.0966 / 0.1621
    [3/5][1400/1583]        Loss_D: 0.5023  Loss_G: 3.1107  D(x): 0.9225    D(G(z)): 0.3231 / 0.0565
    [3/5][1450/1583]        Loss_D: 0.7466  Loss_G: 3.1208  D(x): 0.8441    D(G(z)): 0.3891 / 0.0634
    [3/5][1500/1583]        Loss_D: 0.7135  Loss_G: 2.8145  D(x): 0.8924    D(G(z)): 0.4117 / 0.0765
    [3/5][1550/1583]        Loss_D: 0.7881  Loss_G: 4.0945  D(x): 0.9332    D(G(z)): 0.4717 / 0.0258
    [4/5][0/1583]   Loss_D: 0.6309  Loss_G: 2.2672  D(x): 0.7764    D(G(z)): 0.2761 / 0.1311
    [4/5][50/1583]  Loss_D: 0.8068  Loss_G: 1.4844  D(x): 0.5595    D(G(z)): 0.1015 / 0.2795
    [4/5][100/1583] Loss_D: 0.4912  Loss_G: 2.0030  D(x): 0.7526    D(G(z)): 0.1516 / 0.1674
    [4/5][150/1583] Loss_D: 3.0392  Loss_G: 0.6172  D(x): 0.0896    D(G(z)): 0.0134 / 0.6503
    [4/5][200/1583] Loss_D: 0.6768  Loss_G: 2.5170  D(x): 0.7543    D(G(z)): 0.2852 / 0.0986
    [4/5][250/1583] Loss_D: 1.2451  Loss_G: 0.9252  D(x): 0.3817    D(G(z)): 0.0554 / 0.4569
    [4/5][300/1583] Loss_D: 0.5916  Loss_G: 1.7704  D(x): 0.6588    D(G(z)): 0.1113 / 0.2144
    [4/5][350/1583] Loss_D: 1.3058  Loss_G: 0.6935  D(x): 0.3416    D(G(z)): 0.0394 / 0.5486
    [4/5][400/1583] Loss_D: 0.6206  Loss_G: 3.0787  D(x): 0.8405    D(G(z)): 0.3261 / 0.0609
    [4/5][450/1583] Loss_D: 0.5866  Loss_G: 1.4752  D(x): 0.6981    D(G(z)): 0.1565 / 0.2718
    [4/5][500/1583] Loss_D: 0.5616  Loss_G: 3.0459  D(x): 0.8869    D(G(z)): 0.3223 / 0.0650
    [4/5][550/1583] Loss_D: 0.6073  Loss_G: 3.2580  D(x): 0.7503    D(G(z)): 0.2344 / 0.0500
    [4/5][600/1583] Loss_D: 0.6905  Loss_G: 3.0939  D(x): 0.8591    D(G(z)): 0.3762 / 0.0589
    [4/5][650/1583] Loss_D: 0.5836  Loss_G: 1.7048  D(x): 0.6781    D(G(z)): 0.1227 / 0.2282
    [4/5][700/1583] Loss_D: 0.8543  Loss_G: 3.7586  D(x): 0.8876    D(G(z)): 0.4712 / 0.0337
    [4/5][750/1583] Loss_D: 0.8484  Loss_G: 2.3787  D(x): 0.6606    D(G(z)): 0.2724 / 0.1192
    [4/5][800/1583] Loss_D: 0.5562  Loss_G: 2.1677  D(x): 0.7446    D(G(z)): 0.1887 / 0.1533
    [4/5][850/1583] Loss_D: 0.7600  Loss_G: 1.4960  D(x): 0.5447    D(G(z)): 0.0559 / 0.2722
    [4/5][900/1583] Loss_D: 0.5677  Loss_G: 3.0179  D(x): 0.8308    D(G(z)): 0.2804 / 0.0664
    [4/5][950/1583] Loss_D: 0.5381  Loss_G: 2.9582  D(x): 0.7989    D(G(z)): 0.2345 / 0.0711
    [4/5][1000/1583]        Loss_D: 0.8333  Loss_G: 2.8499  D(x): 0.7720    D(G(z)): 0.3700 / 0.0786
    [4/5][1050/1583]        Loss_D: 0.5125  Loss_G: 1.8930  D(x): 0.7287    D(G(z)): 0.1387 / 0.1848
    [4/5][1100/1583]        Loss_D: 0.4527  Loss_G: 3.0039  D(x): 0.8639    D(G(z)): 0.2413 / 0.0614
    [4/5][1150/1583]        Loss_D: 0.7072  Loss_G: 0.8361  D(x): 0.5589    D(G(z)): 0.0563 / 0.4846
    [4/5][1200/1583]        Loss_D: 0.8619  Loss_G: 4.9323  D(x): 0.9385    D(G(z)): 0.4880 / 0.0112
    [4/5][1250/1583]        Loss_D: 0.6864  Loss_G: 2.4925  D(x): 0.7232    D(G(z)): 0.2431 / 0.1152
    [4/5][1300/1583]        Loss_D: 0.5835  Loss_G: 3.1599  D(x): 0.8430    D(G(z)): 0.3018 / 0.0644
    [4/5][1350/1583]        Loss_D: 0.9119  Loss_G: 4.7225  D(x): 0.9409    D(G(z)): 0.5082 / 0.0154
    [4/5][1400/1583]        Loss_D: 0.3856  Loss_G: 3.1007  D(x): 0.8980    D(G(z)): 0.2238 / 0.0584
    [4/5][1450/1583]        Loss_D: 1.3314  Loss_G: 5.1061  D(x): 0.9395    D(G(z)): 0.6621 / 0.0094
    [4/5][1500/1583]        Loss_D: 0.5882  Loss_G: 1.7242  D(x): 0.6443    D(G(z)): 0.0785 / 0.2306
    [4/5][1550/1583]        Loss_D: 0.5792  Loss_G: 2.0347  D(x): 0.7582    D(G(z)): 0.2143 / 0.1594
    

6. Results

Finally, let's see how we did. Here we will look at three different results. First, we will see how the losses of D and G changed during training. Second, we will visualize G's output on the fixed_noise batch after every epoch. Third, we will look at a batch of real data side by side with a batch of fake data from G.

6.1. Loss versus training iteration

Below is a plot of the losses of D and G versus training iteration.

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    


6.2. Visualization of G's progression

Remember how we saved the generator's output on the fixed_noise batch after every training epoch? Now we can visualize G's training progression with an animation.

    #%%capture
    fig = plt.figure(figsize=(8,8))
    plt.axis("off")
    ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
    ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
    
    HTML(ani.to_jshtml())
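The `HTML(ani.to_jshtml())` call above only renders inside a Jupyter notebook. If you are running the tutorial as a plain script, one option is to save the animation to a GIF file instead. The sketch below shows this with matplotlib's built-in `PillowWriter` (so no external encoder is needed); `img_list` here is faked with random arrays purely so the snippet is self-contained, standing in for the real list of image grids collected during training.

```python
# Sketch: saving the G-progression animation to a GIF file.
# The real img_list from training would be used in place of the
# random stand-in data below.
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; works without a display
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Stand-in for the real img_list: three fake 3x64x64 image grids in [0, 1]
img_list = [np.random.rand(3, 64, 64) for _ in range(3)]

fig = plt.figure(figsize=(4, 4))
plt.axis("off")
# Same construction as in the tutorial: one AxesImage per frame
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# PillowWriter ships with matplotlib and needs no external binary
ani.save("dcgan_progress.gif", writer=animation.PillowWriter(fps=1))
```

The resulting `dcgan_progress.gif` can then be opened in any image viewer or embedded in a report.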
    


6.3. Real Images vs. Fake Images

Finally, let's take a look at some real images and fake images side by side.

    # Grab a batch of real images from the dataloader
    real_batch = next(iter(dataloader))
    
    # Plot the real images
    plt.figure(figsize=(15,15))
    plt.subplot(1,2,1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
    
    # Plot the fake images from the last epoch
    plt.subplot(1,2,2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1],(1,2,0)))
    plt.show()
    


7. Where to Go Next

We have reached the end of this tutorial, but if you want to dig deeper into GANs, you could:

• Train for longer and see how good the results get
• Modify this model to take a different dataset, and possibly change the size of the images and the model architecture
• Check out some other cool GAN projects here
• Create GANs that generate music

Total running time of the script: (28 minutes 38.953 seconds)

8. Original tutorial

    https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

Original post: https://www.cnblogs.com/ghgxj/p/14219051.html