zoukankan      html  css  js  c++  java
  • Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE

    这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现。

    1. 稀疏编码

    首先介绍一下“稀疏编码”这一概念。

           早期学者在黑白风景照片中可以提取到许多16*16像素的图像碎片。而这些图像碎片几乎都可由64种正交的边组合得到。而且组合出一张碎片所需的边的数目很少,即稀疏的。同时在音频中大多数声音也可由几种基本结构组合得到。这其实就是特征的稀疏表达。即使用少量的基本特征来组合更加高层抽象的特征。在神经网络中即体现出前一层是未加工的像素,而后一层就是对这些像素的非线性组合。

           有监督情况下可以利用深层卷积网络来提取特征,而自编码器就是无监督情况下根据自身的高阶特征编码自己。自编码器是输入输出相同的神经网络。其特点是利用稀疏的高阶特征来重构自己。一般而言自编码器的中间隐层节点的数量要小于输入节点的数量,即实现降维过程。因为对于少于输入节点的隐藏层来说无法将输入的全部信息保留,只能优先选择部分重要的特征,而后利用这些特征来复原。此外我们可以给隐层的权重加上L2正则,正则项惩罚因子越大,接近于0的系数越多,从而特征更加稀疏!   

           关于自编码器我们可以加入一些限制使其实现不同的功能,例如去噪自编码(Denoising AutoEncoder)。输入是加了噪声的数据,而输出是原始数据,在学习过程中,只有学到更鲁棒、更频繁的特征模式才能将噪声略去,回复原始数据。如果自编码器的隐层只有一层,那么原理类似于主成分分析PCA。

            HInton提出的DBN模型有多个隐含层,每个隐含层都是限制玻尔兹曼机RBM。DBN训练时需先对每两层间进行无监督的预训练,这一过程实为一个多层的自编码器,可以将每整个网络的权重初始化到一个理想的分布。最后通过反向传播算法调整模型权重,这个步骤会使用经过标注的信息来做监督性的分类训练。当年DBN给训练深度神经网络提供了可能性,它解决了网络过深带来的深度弥散。简言之:先用自编码器的方法进行无监督的预训练,提取特征并初始化权重,然后使用标注信息进行监督式的训练。

    2.  VAE工作流程

    先看下图:

            AE的工作其实是实现了    图片->向量->图片   这一过程。就是说给定一张图片编码后得到一个向量,然后将这一向量进行解码后就得到了原始的图片。这个解码后的图片和之前的原图一样吗?不完全一样。因为一般而言,如前所述是从低维隐层中恢复原图。但是AE另我们现在能训练任意多的图片,如果我们把这些图片的编码向量存在来,那以后就能通过这些编码向量来重构我们的图像,称之为标准自编码器。可这还不够,如果现在我随机拿出一个很离谱的向量直接另其解码,那解码出来的东西十有八九是无意义的东西。

           所以我们希望AE编码出的code符合一种分布(eg:高斯混合模型),那么我们就可以从这个高斯分布任意采样出一个code,给这个code解码那么就会生成一张原图类似的图。而这个强迫分布就是VAE与AE的不同之处了。VAE的编码器输出包括两部分:m和σ。其中e是正态分布, c为编码结果。m、e、σ、c的形状一样,都为(batch_size,latent_code_num) 。这个latent_code_num就相当于高斯混合分布的高斯数量。每个高斯都有自己的均值、方差。所以共有latent_code_num个均值、方差。

           接下来是VAE的损失函数:由两部分的和组成(bce_loss、kld_loss)。bce_loss即为binary_cross_entropy(二分类交叉熵)损失,即用于衡量原图与生成图片的像素误差。kld_loss即为KL-divergence(KL散度),用来衡量潜在变量的分布和单位高斯分布的差异。

     3. Pytorch实现 

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Sat Mar 10 20:48:03 2018
    
    @author: lps
    """
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    from torch.autograd import Variable
    from torchvision import transforms
    import torchvision.datasets as dst
    from torchvision.utils import save_image
    
    
    EPOCH = 15
    BATCH_SIZE = 64
    n = 2   # num_workers
    LATENT_CODE_NUM = 32   
    log_interval = 10
    
    
    transform=transforms.Compose([transforms.ToTensor()])
    data_train = dst.MNIST('MNIST_data/', train=True, transform=transform, download=False)
    data_test = dst.MNIST('MNIST_data/', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset=data_train, num_workers=n,batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=data_test, num_workers=n,batch_size=BATCH_SIZE, shuffle=True)
    
    
    class VAE(nn.Module):
          def __init__(self):
                super(VAE, self).__init__()
          
                self.encoder = nn.Sequential(
                      nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
                      nn.BatchNorm2d(64),
                      nn.LeakyReLU(0.2, inplace=True),
                      
                      nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                      nn.BatchNorm2d(128),
                      nn.LeakyReLU(0.2, inplace=True),
                          
                      nn.Conv2d(128, 128, kernel_size=3 ,stride=1, padding=1),
                      nn.BatchNorm2d(128),
                      nn.LeakyReLU(0.2, inplace=True),                  
                      )
                
                self.fc11 = nn.Linear(128 * 7 * 7, LATENT_CODE_NUM)
                self.fc12 = nn.Linear(128 * 7 * 7, LATENT_CODE_NUM)
                self.fc2 = nn.Linear(LATENT_CODE_NUM, 128 * 7 * 7)
                
                self.decoder = nn.Sequential(                
                      nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
                      nn.ReLU(inplace=True),
                      
                      nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
                      nn.Sigmoid()
                      )
    
          def reparameterize(self, mu, logvar):
                eps = Variable(torch.randn(mu.size(0), mu.size(1))).cuda()
                z = mu + eps * torch.exp(logvar/2)            
                
                return z
          
          def forward(self, x):
                 out1, out2 = self.encoder(x), self.encoder(x)  # batch_s, 8, 7, 7
                 mu = self.fc11(out1.view(out1.size(0),-1))     # batch_s, latent
                 logvar = self.fc12(out2.view(out2.size(0),-1)) # batch_s, latent
                 z = self.reparameterize(mu, logvar)      # batch_s, latent      
                 out3 = self.fc2(z).view(z.size(0), 128, 7, 7)    # batch_s, 8, 7, 7
                 
                 return self.decoder(out3), mu, logvar
    
    
    def loss_func(recon_x, x, mu, logvar):
          BCE = F.binary_cross_entropy(recon_x, x,  size_average=False)
          KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
          
          return BCE+KLD
    
    
    vae = VAE().cuda()
    optimizer =  optim.Adam(vae.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    
    
    def train(EPOCH):
          vae.train()
          total_loss = 0
          for i, (data, _) in enumerate(train_loader, 0):
                data = Variable(data).cuda()
                optimizer.zero_grad()
                recon_x, mu, logvar = vae.forward(data)
                loss = loss_func(recon_x, data, mu, logvar)
                loss.backward()
                total_loss += loss.data[0]
                optimizer.step()
                
                if i % log_interval == 0:
                      sample = Variable(torch.randn(64, LATENT_CODE_NUM)).cuda()
                      sample = vae.decoder(vae.fc2(sample).view(64, 128, 7, 7)).cpu()
                      save_image(sample.data.view(64, 1, 28, 28),
                   'result/sample_' + str(epoch) + '.png')
                      print('Train Epoch:{} -- [{}/{} ({:.0f}%)] -- Loss:{:.6f}'.format(
                                  epoch, i*len(data), len(train_loader.dataset), 
                                  100.*i/len(train_loader), loss.data[0]/len(data)))
                      
          print('====> Epoch: {} Average loss: {:.4f}'.format(
              epoch, total_loss / len(train_loader.dataset)))
          
    for epoch in range(1, EPOCH):
        train(epoch)
          
    
    main.py

    编解码器可由全连接或卷积网络实现。这里采用CNN。结果如下:

           

           

             

    参考 :

    《Tensoflow 实战》

    Pytorch tutorial

    Paper-Implementations

    /pytorch-tutorial

  • 相关阅读:
    CF 461B Appleman and Tree
    POJ 1821 Fence
    NOIP 2012 开车旅行
    CF 494B Obsessive String
    BZOJ2337 XOR和路径
    CF 24D Broken robot
    POJ 1952 BUY LOW, BUY LOWER
    SPOJ NAPTIME Naptime
    POJ 3585
    CF 453B Little Pony and Harmony Chest
  • 原文地址:https://www.cnblogs.com/jiangkejie/p/11179901.html
Copyright © 2011-2022 走看看