zoukankan      html  css  js  c++  java
  • 变分自编码器VAE的由来和简单实现(PyTorch)

    变分自编码器VAE的由来和简单实现(PyTorch)

    ​ 之前经常遇到变分自编码器的概念((VAE)),但是自己对于这个概念总是模模糊糊,今天就系统的对(VAE)进行一些整理和回顾。

    VAE的由来

    ​ 假设有一个目标数据(X={X_1,X_2,cdots,X_n}),我们想生成一些数据,即生成(hat{X}={hat{X_1},hat{X_2},cdots,hat{X_n}}),其分布与(X)相同。

    ​ 但是实际上,这样存在一些问题,第一是我们如何将生成的(hat{X})(X)一一对应,这就需要我们采用更为精巧的度量方式,即如何度量两个分布之间的距离;第二是我们如何生成新的(hat{X}),按照朴素的想法,我们可以构造一个函数(G),使得(hat{X}=G(Z)) ,如果能构造出这个(G),我们就可以通过一个任意的(Z),来生成(hat{X}) ,而这里的(Z),可以取一个已知的分布,比如正态分布。

    目前的问题

    ​ 目前的问题转化为了如何构造(G),以及如何检验我们生成的(hat{X})是否和(X)具有同分布。在(GAN)中,这里的(G)和分布的相似度衡量都用神经网络搞定了,一个叫做(generator),一个叫做(discriminator),这二者互相拮抗,最终使得分布越来接近。

    ​ 而在我们目前的问题中,(VAE)提供了另外一种思路,沿着AutoEncoder的想法,AutoEncoder是通过(encoder)把image (a)编码为vector,叫做(latent{ }represention) ,再通过(decoder)(latent{ }space)转为(hat{a}) ,(hat{a})(a)的重建图像。

    ​ 但是AE针对每张图片生成的(latent{ }code)并没有可解释性,即sample两个(latent{ } code)之间的点输入(decoder),得到的结果并不一定具有跟这两个(latent code)相关的特征。为了解决这个问题,提出了VAE:不再采用vector来建模一个(latent{ }code),而是利用一个带有noise的高斯分布来表示。直观的理解,在加入noise之后,就有机会将训练时候train的(latent{ }code)在其latent space下赋予一定的变化能力,使latent space变得更加连续,从而可以在其中采样从而生成新的图片。

    ​ 我们之前生成的(Z={Z_1,Z_2,cdots,Z_n}),现在不再单单生成一个(Z),而是生成两个vector,分别记为(M={{mu_1},{mu_2},cdots\,{mu_n}}),(Sigma={ {sigma_1},{sigma_2},cdots,sigma_n}),分别代表新生成latent code的高斯分布的均值和方差。在sample的时候就只需要根据从标准正态分布(mathcal{N}(0,1))中采样一个(e_i),(e_i)来自于(E={e_1,e_2,cdots,e_n}),然后利用(c_i=e_i*exp({sigma_i})+mu_i)((reparameterization{ }trick)),就得到了我们所需的(c_i)(c_i)即组成我们需要的(Z)=({c_1,c_2,cdots,c_n})

    ​ 这里一方面希望(VAE)能够生成尽可能丰富的数据,因此训练的时候希望在高斯分布中含有噪声。另一方面优化的过程中会趋向于使图像质量更好,因此当噪声为0的时候退化为普通的(AutoEncoder),这种情况我们是不希望出现的。为了平衡这种trade-off,这里希望每个(p(Z|X))能够接近标准正态分布,但是另一方面网络又趋于使输入和输出图像更为接近,因此会使正态分布的方差向0的方向优化。经过这种对抗过程,最终就能产生具有一定可解释性的(decoder),同时最终得到的(Z)的分布也会趋向于(mathcal{N}(0,1)),可以表示为:

    ​ $$p(Z)=sum_{X} p(Z mid X) p(X)=sum_{X} mathcal{N}(0, 1) p(X)=mathcal{N}(0, I) sum_{X} p(X)=mathcal{N}(0, 1)$$

    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
    
            self.fc1 = nn.Linear(784, 400)
            self.fc21 = nn.Linear(400, 20)
            self.fc22 = nn.Linear(400, 20)
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 784)
    
        def encode(self, x):
            h1 = F.relu(self.fc1(x))
            return self.fc21(h1), self.fc22(h1)
    
        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return mu + eps*std
    
        def decode(self, z):
            h3 = F.relu(self.fc3(z))
            return torch.sigmoid(self.fc4(h3))
    
        def forward(self, x):
            mu, logvar = self.encode(x.view(-1, 784))
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar
        
        def loss_function_original(recon_x, x, mu, logvar):
            BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            return BCE + KLD
    

    ​ 这里的loss由两部分组成,一部分是重建loss,一部分是使各个高斯分布趋近于标准高斯分布的loss(由KL散度推导得到)。

  • 相关阅读:
    小程序 中 自定义单选框样式
    自定义checkbox和radio的样式
    postgre修改字段数据类型
    xpat如何获取一个标签里的属性值
    一看就明白的爬虫入门讲解:基础理论篇
    第一天
    poj 1979 走多少个‘ . '问题 dfs算法
    ACM 贪心算法总结
    poj 2393 奶牛场生产成本问题 贪心算法
    poj 3262 牛毁坏花问题 贪心算法
  • 原文地址:https://www.cnblogs.com/upuphe/p/15495818.html
Copyright © 2011-2022 走看看