  • Deep Learning: A Generative Adversarial Network (GAN) Example

    I've recently been studying deep learning and came across generative adversarial networks (GANs), which I found quite interesting. I found some code online, ran it, and confirmed that it does work. People online say GANs are slow to train and hard to get to converge; you don't know until you try, and having tried it myself I can confirm the training is indeed time-consuming. The code given in this post is just the most basic model I found online, personally tested and working.

    import torch
    import torch.nn as nn
    from torchvision import datasets
    from torchvision import transforms
    from torchvision.utils import save_image
    
    
    def to_var(x):
        # Variable is deprecated in modern PyTorch; tensors track gradients
        # on their own, so we only need to move the data to the GPU when available.
        if torch.cuda.is_available():
            x = x.cuda()
        return x
    
    def denorm(x):
        out = (x + 1) / 2
        return out.clamp(0, 1)
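
    # denorm() inverts the Normalize transform below: pixels are mapped from
    # [0, 1] to [-1, 1] via (x - 0.5) / 0.5, so (x + 1) / 2 maps them back, e.g.
    # denorm(torch.tensor([-1., 0., 1.])) -> tensor([0.0000, 0.5000, 1.0000])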
    
    # Image processing: normalize single-channel MNIST pixels from [0, 1] to [-1, 1]
    transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5,),
                                         std=(0.5,))])
    # MNIST dataset
    mnist = datasets.MNIST(root='./data/',
                           train=True,
                           transform=transform,
                           download=True)
    # Data loader
    data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                              batch_size=100, 
                                              shuffle=True)
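    # MNIST has 60,000 training images, so batch_size=100 gives 600 steps per
    # epoch; the training loop below logs twice per epoch (every 300 steps).
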
    # Discriminator: flattened 28x28 image (784 values) -> probability that it is real
    D = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1),
        nn.Sigmoid())
    
    # Generator: 64-dim noise vector -> flattened 28x28 image; Tanh keeps
    # outputs in [-1, 1], matching the range of the normalized real images
    G = nn.Sequential(
        nn.Linear(64, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 784),
        nn.Tanh())
    
    if torch.cuda.is_available():
        D.cuda()
        G.cuda()
    
    # Binary cross entropy loss and optimizer
    criterion = nn.BCELoss()
    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
    
    # Start training
    for epoch in range(200):
        for i, (images, _) in enumerate(data_loader):
            # Build mini-batch dataset
            batch_size = images.size(0)
            images = to_var(images.view(batch_size, -1))
            
            # Create the labels which are later used as input for the BCE loss;
            # shaped (batch_size, 1) to match the discriminator's output
            real_labels = to_var(torch.ones(batch_size, 1))
            fake_labels = to_var(torch.zeros(batch_size, 1))
    
            #============= Train the discriminator =============#
            # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
            # Second term of the loss is always zero since real_labels == 1
            outputs = D(images)
            d_loss_real = criterion(outputs, real_labels)
            real_score = outputs
            
            # Compute BCELoss using fake images
            # First term of the loss is always zero since fake_labels == 0
            z = to_var(torch.randn(batch_size, 64))
            fake_images = G(z)
            outputs = D(fake_images.detach())  # detach: no gradient into G during the D update
            d_loss_fake = criterion(outputs, fake_labels)
            fake_score = outputs
            
            # Backprop + Optimize
            d_loss = d_loss_real + d_loss_fake
            D.zero_grad()
            d_loss.backward()
            d_optimizer.step()
            
            #=============== Train the generator ===============#
            # Compute loss with fake images
            z = to_var(torch.randn(batch_size, 64))
            fake_images = G(z)
            outputs = D(fake_images)
            
            # We train G to maximize log(D(G(z))) instead of minimizing log(1 - D(G(z))),
            # which gives much stronger gradients early in training.
            # For the reasoning, see the last paragraph of Section 3 of https://arxiv.org/pdf/1406.2661.pdf
            g_loss = criterion(outputs, real_labels)
            
            # Backprop + Optimize
            # Zero D's gradients too: g_loss.backward() also accumulates gradients
            # in D, and we don't want them carried into the next D update.
            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            g_optimizer.step()
            
            if (i+1) % 300 == 0:
                print('Epoch [%d/%d], Step [%d/%d], d_loss: %.4f, '
                      'g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f'
                      % (epoch+1, 200, i+1, len(data_loader), d_loss.item(),
                         g_loss.item(), real_score.mean().item(),
                         fake_score.mean().item()))
        
        # Save real images
        if (epoch+1) == 1:
            images = images.view(images.size(0), 1, 28, 28)
            save_image(denorm(images.data), './data/real_images.png')
        
        # Save sampled images
        fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
        save_image(denorm(fake_images.data), './data/fake_images-%d.png' %(epoch+1))
    
    # Save the trained parameters 
    torch.save(G.state_dict(), './generator.pkl')
    torch.save(D.state_dict(), './discriminator.pkl')
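
    For reference, here's a minimal sketch of how the saved generator could be reloaded later to draw fresh samples. It assumes the same 64-dimensional noise input and the './generator.pkl' file produced by the script above; the output path 'sampled_images.png' is just an example, not something from the original post.

    import torch
    import torch.nn as nn
    from torchvision.utils import save_image

    # Rebuild the generator with the exact architecture used during training
    G = nn.Sequential(
        nn.Linear(64, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 784),
        nn.Tanh())
    G.load_state_dict(torch.load('./generator.pkl'))
    G.eval()

    with torch.no_grad():
        z = torch.randn(100, 64)            # 100 random latent vectors
        samples = G(z).view(-1, 1, 28, 28)  # reshape to image format
        save_image((samples + 1) / 2,       # undo the [-1, 1] normalization
                   './data/sampled_images.png')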

    I couldn't track down where this code originally came from, so I'm not citing a source for it.

    The output of a training run and the final generated images are shown as screenshots in the original post.

  • Original post: https://www.cnblogs.com/devilmaycry812839668/p/9801453.html