    PyTorch Implementation of a GAN

    Version V1
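
    This first version is a fully connected (MLP) GAN trained on MNIST: the generator maps a 100-dimensional noise vector to a 784-dimensional (28×28) image, and the discriminator maps a flattened image to a single real/fake probability. Both networks are trained with binary cross-entropy loss and Adam.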

    import torch
    import torch.nn as nn
    from torchvision import transforms, datasets
    from torch import optim as optim
    import matplotlib
    matplotlib.use('AGG')  # or another non-interactive backend such as PDF, SVG, or PS; plt.show() is then a no-op
    import matplotlib.pyplot as plt
    import time
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    
    
    batch_size = 100
    # MNIST dataset
    dataset = datasets.MNIST(
        root='./data/',
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]),
        download=True
    )
    
    # Data loader
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    
    print(f"Length of total dataset = {len(dataset)}, 
    Length of dataloader with having batch_size of {batch_size} = {len(dataloader)}")
    
    dataiter = iter(dataloader)
    images, labels = next(dataiter)  # dataiter.next() was removed in recent PyTorch; use the built-in next()
    print(torch.min(images),torch.max(images))
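    # Normalize([0.5], [0.5]) rescales pixels from [0, 1] to [-1, 1],
    # matching the Tanh output range of the generator defined below,
    # so the values printed above should be -1.0 and 1.0.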
    
    class GeneratorModel(nn.Module):
        def __init__(self):
            super(GeneratorModel, self).__init__()
            input_dim = 100
            output_dim = 784
            # <---------- All non-output layers of both D and G use the LeakyReLU activation ---------->
            self.hidden_layer1 = nn.Sequential(
                nn.Linear(input_dim, 256),
                nn.LeakyReLU(0.2)
            )
    
            self.hidden_layer2 = nn.Sequential(
                nn.Linear(256, 512),
                nn.LeakyReLU(0.2)
            )
    
            self.hidden_layer3 = nn.Sequential(
                nn.Linear(512, 1024),
                nn.LeakyReLU(0.2)
            )
            # <---------- G's output layer uses the Tanh activation ---------->
            self.hidden_layer4 = nn.Sequential(
                nn.Linear(1024, output_dim),
                nn.Tanh()
            )
        
        def forward(self, x):
            output = self.hidden_layer1(x)
            output = self.hidden_layer2(output)
            output = self.hidden_layer3(output)
            output = self.hidden_layer4(output)
            return output.to(device)
            
    class DiscriminatorModel(nn.Module):
        def __init__(self):
            super(DiscriminatorModel, self).__init__()
            input_dim = 784
            output_dim = 1
    
            self.hidden_layer1 = nn.Sequential(
                nn.Linear(input_dim, 1024),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            )
    
            self.hidden_layer2 = nn.Sequential(
                nn.Linear(1024, 512),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            )
    
            self.hidden_layer3 = nn.Sequential(
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            )
    
            # <---------- D's output layer uses the Sigmoid activation ---------->
            self.hidden_layer4 = nn.Sequential(
                nn.Linear(256, output_dim),
                nn.Sigmoid()
            )
    
        def forward(self, x):
            output = self.hidden_layer1(x)
            output = self.hidden_layer2(output)
            output = self.hidden_layer3(output)
            output = self.hidden_layer4(output)
            return output.to(device)
    
    discriminator = DiscriminatorModel()
    generator = GeneratorModel()
    discriminator.to(device)
    generator.to(device)
    print(generator,"
    
    
    ",discriminator)
    
    
    # <---------- Binary cross-entropy loss ---------->
    criterion = nn.BCELoss() 
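    # BCE(p, y) = -[y*log(p) + (1-y)*log(1-p)]: a target of y=1 pushes the
    # prediction p toward 1, a target of y=0 pushes it toward 0.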
    
    # <---------- Adam optimizers ---------->
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002) 
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
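    # Note: this keeps Adam's default betas=(0.9, 0.999); many GAN implementations
    # instead use betas=(0.5, 0.999), as suggested in the DCGAN paper.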
    
    
    num_epochs = 100
    batch = 100
    outputs=[]
    
    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []
    
    for epoch_idx in range(num_epochs):
        start_time = time.time()
        for batch_idx, data_input in enumerate(dataloader):
          
            real = data_input[0].view(batch, 784).to(device)  # batch_size x 784
            labels = data_input[1]  # class labels; unused by a vanilla GAN
    
            noise = torch.randn(batch,100).to(device)
            fake = generator(noise) # batch_size X 784
    
            disc_real = discriminator(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    
            disc_fake = discriminator(fake).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    
            # <---------- D's loss is the average of lossD_real and lossD_fake ---------->
            lossD = (lossD_real + lossD_fake) / 2
            real_score = torch.mean(disc_real).item()
            fake_score = torch.mean(disc_fake).item()
            
            d_optimizer.zero_grad()      
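            # retain_graph=True keeps the computation graph through `fake` alive,
            # so that lossG.backward() below can still backpropagate into the generator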
            lossD.backward(retain_graph=True)
            d_optimizer.step()        
            
            gen_fake = discriminator(fake).view(-1)
            # <---------- G's loss pushes D's output on fake images toward 1 ---------->
            lossG = criterion(gen_fake, torch.ones_like(gen_fake))
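            # (This is the non-saturating generator loss -log D(G(z)), which gives
            # stronger gradients early in training than minimizing log(1 - D(G(z))).)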
            
            g_optimizer.zero_grad()
            lossG.backward()
            g_optimizer.step()
            
            if ((batch_idx + 1)% 600 == 0 and (epoch_idx + 1)%10 == 0):
                print("Training Steps Completed: ", batch_idx)
                
                with torch.no_grad():
                    generated_data = fake.cpu().view(batch, 28, 28)
                    real_data = real.cpu().view(batch, 28, 28)
                    i = 0
                    j = 0
                    plt.figure(figsize=(10,2))
                    print("Real Images")
                    for x in real_data:
                        if(i>=10): break
                        plt.subplot(2,10,i+1)
                        plt.imshow(x.detach().numpy(), interpolation='nearest',cmap='gray')
                        i = i+1
                    plt.title("on "+str((epoch_idx + 1))+ "th epoch")
                    plt.savefig(f'real_images_epoch_{epoch_idx+1}.jpg')  # AGG backend: save instead of show (filename arbitrary)
    
                    print("Generated Images")
                    plt.figure(figsize=(10,2))
                    for x in generated_data:
                        if(j>=10): break
                        plt.subplot(2,10,j+1)
                        plt.imshow(x.detach().numpy(), interpolation='nearest',cmap='gray')
                        j = j+1
                    plt.savefig(f'generated_images_epoch_{epoch_idx+1}.jpg')  # save instead of show (filename arbitrary)
    
        outputs.append((epoch_idx,real,fake))
        losses_g.append(lossG.item())
        losses_d.append(lossD.item())
        real_scores.append(real_score)
        fake_scores.append(fake_score)
    
        print('Epochs [{}/{}] & Batch [{}/{}]: loss_d: {:.4f}, loss_g: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}, took time: {:.0f}s'.format(
                (epoch_idx+1), num_epochs, batch_idx+1, len(dataloader), lossD.item(), lossG.item(), real_score, fake_score, time.time()-start_time))
    
        if epoch_idx % 10 == 0:
            plt.plot(losses_d, '-')
            plt.plot(losses_g, '-')
            plt.xlabel('epoch')
            plt.ylabel('loss')
            plt.legend(['Discriminator', 'Generator'])
            plt.title('Losses')
            plt.savefig('Losses.jpg')
            plt.show()
            plt.close()
            
            plt.plot(real_scores, '-')
            plt.plot(fake_scores, '-')
            plt.xlabel('epoch')
            plt.ylabel('score')
            plt.legend(['Real', 'Fake'])
            plt.title('Scores')
            plt.savefig('Scores.jpg')
            plt.show()
            plt.close()
            
            
            # Save trained models
            torch.save(generator.state_dict(), 'generator.pth')
            torch.save(discriminator.state_dict(), 'discriminator.pth')
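
    After training, the saved generator weights can be reloaded to generate digits without rerunning the training loop. Below is a minimal sketch (not from the original post), assuming the GeneratorModel class and device defined above are in scope; the samples.jpg filename is arbitrary.

    # Minimal inference sketch: reload the trained generator and sample 10 digits
    generator = GeneratorModel().to(device)
    generator.load_state_dict(torch.load('generator.pth', map_location=device))
    generator.eval()

    with torch.no_grad():
        noise = torch.randn(10, 100).to(device)
        samples = generator(noise).cpu().view(10, 28, 28)

    samples = (samples + 1) / 2  # map Tanh output from [-1, 1] back to [0, 1] for display

    plt.figure(figsize=(10, 2))
    for k in range(10):
        plt.subplot(1, 10, k + 1)
        plt.imshow(samples[k].numpy(), cmap='gray')
        plt.axis('off')
    plt.savefig('samples.jpg')  # arbitrary output filename
    plt.close()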
    



  • Original post: https://www.cnblogs.com/lwp-nicol/p/14906294.html