  • How GANs work, with code

    The overall GAN objective:
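
    This is the standard minimax objective from the original GAN paper (Goodfellow et al., 2014), written with this article's V(G, D). Here x is a real image X, z is a noise vector Z, and p_data and p_z (notation introduced here) are the real-image and noise distributions:

$$
\min_G \max_D V(G, D) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$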

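    The terms are defined as follows:

    X is a real image, and its label is 1.

    G(Z) is the image that the generator G produces from the given noise Z (strictly speaking, G maps Z to a tensor), and its label is 0.

    D is a binary classification network that judges whether a given image is real or fake.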

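    How D and G update their parameters:

    D takes both real and generated images as input and updates its parameters with BCE (binary cross-entropy).

    For a generated image G(Z), D outputs a prediction L. G tries to make L as close to true (1) as possible, and updates its own parameters with BCE against the label 1; only G's parameters change in this step.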
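
    The two update rules above can be sketched in pure Python for scalar discriminator outputs. This is a minimal illustration, not the training code: the helper names (`bce`, `mean`, `discriminator_loss`, `generator_loss`, `batch_value`) are hypothetical, and the inputs are made-up probabilities standing in for D's outputs. Like `nn.BCELoss` with its default mean reduction, each loss is averaged over the batch. The sketch also shows that D's loss is the negative of the batch estimate of V(G, D), so minimizing D's BCE loss is the same as maximizing V.

```python
import math


def bce(p: float, y: float) -> float:
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def mean(values):
    return sum(values) / len(values)


def discriminator_loss(d_real, d_fake):
    """D step: real images are labeled 1, generated images 0 (mean BCE per batch)."""
    return mean([bce(p, 1.0) for p in d_real]) + mean([bce(p, 0.0) for p in d_fake])


def generator_loss(d_fake):
    """G step: generated images are labeled 1, so the loss is -log D(G(z))."""
    return mean([bce(p, 1.0) for p in d_fake])


def batch_value(d_real, d_fake):
    """Batch estimate of V(G, D) = E[log D(x)] + E[log(1 - D(G(z)))]."""
    return mean([math.log(p) for p in d_real]) + mean([math.log(1 - p) for p in d_fake])


# Made-up discriminator outputs: D(x) for two real images, D(G(z)) for two fakes.
d_real = [0.9, 0.8]
d_fake = [0.2, 0.1]
print("d_loss:", discriminator_loss(d_real, d_fake))
print("g_loss:", generator_loss(d_fake))
print("V estimate:", batch_value(d_real, d_fake))  # equals -d_loss up to rounding
```

    Unlike `nn.BCELoss`, this sketch does not clamp the logarithms, so p must stay strictly between 0 and 1.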

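    Deriving the formula:

    G wants D to be unable to tell that its generated images are fake, so it wants D(G(Z)) to be as large as possible, which makes V(G,D) as small as possible. D, on the other hand, wants D(G(Z)) to be small and D(X) to be large, which makes V(G,D) as large as possible. This is why G minimizes V(G,D) while D maximizes it.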
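
    For a single real image x and noise vector z, one objective with exactly these properties is the following (a reconstruction from the description, not necessarily the author's original notation). G lowers V by raising D(G(z)); D raises V by raising D(x) and lowering D(G(z)):

$$
\min_G \max_D V(G, D) = D(x) + \big(1 - D(G(z))\big)
$$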

    To make the objective easier to differentiate, a log is applied to each term, giving:
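
    For a single real image x and noise vector z, the log form is shown below. Since log is increasing, log D(x) still grows with D(x) and log(1 - D(G(z))) still shrinks as D(G(z)) grows, so G still lowers V and D still raises it. Each term is also the negative of a BCE term: log D(x) = -BCE(D(x), 1) and log(1 - D(G(z))) = -BCE(D(G(z)), 0), which is why D can be trained with `nn.BCELoss`:

$$
\min_G \max_D V(G, D) = \log D(x) + \log\big(1 - D(G(z))\big)
$$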

    Finally, taking the expectation over real images and noise (estimated in practice by averaging over a batch) gives:
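
    Written out, with p_data the real-image distribution and p_z the noise distribution (notation introduced here; the code draws z from a standard normal with `torch.randn`), the first line is the full objective and the second is its estimate over a batch of m real images x^(i) and m noise vectors z^(i) (m is the batch size, 100 in the code):

$$
\min_G \max_D V(G, D) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$

$$
V(G, D) \approx \frac{1}{m}\sum_{i=1}^{m} \log D\big(x^{(i)}\big) + \frac{1}{m}\sum_{i=1}^{m} \log\Big(1 - D\big(G(z^{(i)})\big)\Big)
$$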

    PyTorch code for a GAN trained on MNIST:

    import itertools
    import math
    import os

    import torch
    import torch.nn as nn
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    from IPython import display

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # MNIST images have a single channel: map pixel values from [0, 1] to [-1, 1],
    # which matches the Tanh output range of the generator.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])

    train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)


    class Discriminator(nn.Module):
        """MLP that maps a flattened 28x28 image to the probability that it is real."""

        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(784, 1024),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
                nn.Linear(1024, 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
                nn.Linear(256, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            out = self.model(x.view(x.size(0), 784))
            out = out.view(out.size(0), -1)  # shape: (batch, 1)
            return out


    class Generator(nn.Module):
        """MLP that maps a 100-dim noise vector z to a flattened 28x28 image in [-1, 1]."""

        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(100, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 1024),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(1024, 784),
                nn.Tanh()
            )

        def forward(self, x):
            x = x.view(x.size(0), -1)
            out = self.model(x)
            return out


    discriminator = Discriminator().to(device)
    generator = Generator().to(device)
    criterion = nn.BCELoss()
    lr = 0.0002
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)


    def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels):
        """One D step: BCE on real images (label 1) plus BCE on generated images (label 0)."""
        discriminator.zero_grad()
        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)
        real_score = outputs

        # detach() keeps gradients from flowing back into the generator during the D step
        outputs = discriminator(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()
        return d_loss, real_score, fake_score


    def train_generator(generator, discriminator_outputs, real_labels):
        """One G step: score generated images against label 1 so that G learns to fool D."""
        generator.zero_grad()
        g_loss = criterion(discriminator_outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()  # only the generator's parameters are updated here
        return g_loss


    # draw fixed noise samples to inspect the generator's output during training
    num_test_samples = 16
    test_noise = torch.randn(num_test_samples, 100, device=device)
    # create a square grid of subplots for the samples
    size_figure_grid = int(math.sqrt(num_test_samples))
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    # set number of epochs and initialize figure counter
    num_epochs = 200
    num_batches = len(train_loader)
    num_fig = 0
    os.makedirs('results', exist_ok=True)  # plt.savefig needs the folder to exist

    for epoch in range(num_epochs):
        for n, (images, _) in enumerate(train_loader):
            images = images.to(device)
            # nn.BCELoss needs targets with the same shape as D's output: (batch, 1)
            real_labels = torch.ones(images.size(0), 1, device=device)
            fake_labels = torch.zeros(images.size(0), 1, device=device)

            # Sample from generator
            noise = torch.randn(images.size(0), 100, device=device)
            fake_images = generator(noise)

            # Train the discriminator
            d_loss, real_score, fake_score = train_discriminator(discriminator, images, real_labels,
                                                                 fake_images, fake_labels)

            # Sample again from the generator and score it with the updated discriminator
            noise = torch.randn(images.size(0), 100, device=device)
            fake_images = generator(noise)
            outputs = discriminator(fake_images)

            # Train the generator
            g_loss = train_generator(generator, outputs, real_labels)

            if (n + 1) % 100 == 0:
                with torch.no_grad():  # no gradients are needed for visualization
                    test_images = generator(test_noise)

                for k in range(num_test_samples):
                    i = k // size_figure_grid
                    j = k % size_figure_grid
                    ax[i, j].cla()
                    ax[i, j].imshow(test_images[k, :].cpu().numpy().reshape(28, 28), cmap='Greys')
                display.clear_output(wait=True)
                display.display(plt.gcf())

                plt.savefig('results/mnist-gan-%03d.png' % num_fig)
                num_fig += 1
                # .item() converts a 0-dim tensor into a Python float
                print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                      'D(x): %.2f, D(G(z)): %.2f'
                      % (epoch + 1, num_epochs, n + 1, num_batches, d_loss.item(), g_loss.item(),
                         real_score.mean().item(), fake_score.mean().item()))

    plt.close(fig)
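
    In the training loop, `train_generator` minimizes BCE(D(G(z)), 1) = -log D(G(z)) rather than minimizing log(1 - D(G(z))) from the minimax formula. This is the "non-saturating" generator loss suggested in the original GAN paper (Goodfellow et al., 2014): both losses push D(G(z)) toward 1, but when D confidently rejects generated images early in training, log(1 - D(G(z))) gives almost no gradient. The pure-Python sketch below compares the gradients of both losses with respect to the discriminator's pre-sigmoid logit a. It assumes D(G(z)) = sigmoid(a), which matches the `nn.Sigmoid` at the end of the Discriminator; the function names are hypothetical.

```python
import math


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def minimax_loss(a: float) -> float:
    """Generator term of the minimax objective, log(1 - D(G(z))), which G minimizes."""
    return math.log(1.0 - sigmoid(a))


def non_saturating_loss(a: float) -> float:
    """BCE(D(G(z)), 1) = -log D(G(z)), the loss minimized by train_generator."""
    return -math.log(sigmoid(a))


def grad_minimax(a: float) -> float:
    # d/da log(1 - sigmoid(a)) = -sigmoid(a): vanishes when D(G(z)) is near 0
    return -sigmoid(a)


def grad_non_saturating(a: float) -> float:
    # d/da -log(sigmoid(a)) = -(1 - sigmoid(a)): stays near -1 when D(G(z)) is near 0
    return -(1.0 - sigmoid(a))


# a = -6 means D confidently rejects the fake (D(G(z)) is about 0.0025).
for a in (-6.0, 0.0, 6.0):
    print(f"D(G(z))={sigmoid(a):.4f}  "
          f"minimax grad={grad_minimax(a):+.4f}  "
          f"non-saturating grad={grad_non_saturating(a):+.4f}")
```

    Both gradients are negative, so gradient descent raises D(G(z)) either way; the difference is the magnitude when the generator is still poor, which is when it most needs a learning signal.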
  • Original article: https://www.cnblogs.com/AntonioSu/p/11984053.html