  • PyTorch: Auto-Encoder and Variational Auto-Encoder (lossy image compression)

    Notes

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader
    from torchvision import transforms, datasets

    import visdom
    

    1. Auto-Encoder (AE)

    class AE(nn.Module):
    
        def __init__(self):
            super(AE, self).__init__()
    
            # [b, 784] => [b, 20]
            self.encoder = nn.Sequential(
                nn.Linear(784, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 20),
                nn.ReLU()
            )
            # [b, 20] => [b, 784]
            self.decoder = nn.Sequential(
                nn.Linear(20, 64),
                nn.ReLU(),
                nn.Linear(64, 256),
                nn.ReLU(),
                nn.Linear(256, 784),
                nn.Sigmoid()
            )
    
        def forward(self, x):                 #x.shape=[b, 1, 28, 28]
    
            batchsz = x.size(0)
            x = x.view(batchsz, 784)          #flatten
            x = self.encoder(x)               #encoder [b, 20]
            x = self.decoder(x)               #decoder [b, 784]
            x = x.view(batchsz, 1, 28, 28)    #reshape [b, 1, 28, 28]
    
            return x, None                    # no KL term for the plain auto-encoder
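
    A quick shape check (my own sketch, not part of the original notes): running a random MNIST-sized batch through the AE should give back a tensor of the same shape, since the decoder maps the 20-dim code back to 784 pixels.

    # sanity check: reconstruct a random batch shaped like MNIST
    model = AE()
    x = torch.rand(32, 1, 28, 28)        # fake batch of values in [0, 1]
    x_hat, _ = model(x)                  # the AE returns (reconstruction, None)
    print(x_hat.shape)                   # torch.Size([32, 1, 28, 28])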
    

    2. Variational Auto-Encoder (VAE)

    The h in the code and the c_i in the original figure are computed slightly differently: the code does not use an exponential, i.e. the encoder outputs sigma directly instead of a log-variance.

    KL divergence (the torch.randn_like(sigma) factor that multiplies sigma in the code is sampled from a standard normal distribution):
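    For a diagonal Gaussian posterior N(μ, σ²) against the standard normal prior N(0, 1), the closed form that the kld line in the code below implements is

        KL( N(μ, σ²) || N(0, 1) ) = 0.5 * Σ ( μ² + σ² - log(σ²) - 1 )

    (the code adds 1e-8 inside the log for numerical stability and divides by batchsz*28*28 so the term is on the same per-pixel scale as the reconstruction loss).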

    class VAE(nn.Module):
    
        def __init__(self):
            super(VAE, self).__init__()
    
            # [b, 784] => [b, 20]
            self.encoder = nn.Sequential(
                nn.Linear(784, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 20),
                nn.ReLU()
            )
            # [b, 10] => [b, 784]  (the latent code is 10-dim after splitting into mu and sigma)
            self.decoder = nn.Sequential(
                nn.Linear(10, 64),
                nn.ReLU(),
                nn.Linear(64, 256),
                nn.ReLU(),
                nn.Linear(256, 784),
                nn.Sigmoid()
            )
    
            self.criteon = nn.MSELoss()   # defined but unused here; the reconstruction loss is computed in main()
    
        def forward(self, x):              #x.shape=[b, 1, 28, 28]
    
            batchsz = x.size(0)
            x = x.view(batchsz, 784)                 #flatten
    
            h_ = self.encoder(x)                     #encoder  [b, 20], including mean and sigma
            mu, sigma = h_.chunk(2, dim=1)           #[b, 20] => mu[b, 10] and sigma[b, 10]
               
            kld = 0.5 * torch.sum(mu**2 + sigma**2 - torch.log(1e-8 + sigma**2) - 1) / (batchsz*28*28)   # KL divergence, closed form for N(mu, sigma^2) vs N(0, 1)
            
            h = mu + sigma * torch.randn_like(sigma) # reparameterization trick: h = mu + sigma * epsilon, epsilon ~ N(0, 1)
            x_hat = self.decoder(h)                  # decoder  [b, 784]
            x_hat = x_hat.view(batchsz, 1, 28, 28)   # reshape  [b, 1, 28, 28]
    
            return x_hat, kld   # objective: maximize the evidence lower bound (ELBO)
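
    A matching sanity check for the VAE (again my own sketch, assuming the class above): besides the reconstruction, the forward pass returns a scalar KL term that the training loop adds to the loss.

    # sanity check: forward a random batch and inspect both outputs
    model = VAE()
    x = torch.rand(32, 1, 28, 28)        # fake batch of values in [0, 1]
    x_hat, kld = model(x)
    print(x_hat.shape)                   # torch.Size([32, 1, 28, 28])
    print(kld.item())                    # non-negative scalar KL term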
    

    3. Running both models on the MNIST dataset

    def main():
        mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([transforms.ToTensor()]), download=True)
        mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    
        mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([transforms.ToTensor()]), download=True)
        mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)
    
        x, _ = next(iter(mnist_train))     # x: torch.Size([32, 1, 28, 28]), _: torch.Size([32])
    
        model = AE()
        # model = VAE()
    
        criteon = nn.MSELoss()             # mean squared error reconstruction loss
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        print(model)
    
        viz = visdom.Visdom()
    
        for epoch in range(20):
    
            for batchidx, (x, _) in enumerate(mnist_train):
    
                x_hat, kld = model(x)
                loss = criteon(x_hat, x)        # x_hat and x both have shape [b, 1, 28, 28]
    
                if kld is not None:
                    elbo = - loss - 1.0 * kld   # ELBO (evidence lower bound): negative reconstruction loss minus KL
                    loss = - elbo               # maximizing the ELBO = minimizing -ELBO
    
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
            print(epoch, 'loss:', loss.item())
            # print(epoch, 'loss:', loss.item(), 'kld:', kld.item())
    
            x, _ = next(iter(mnist_test))      # grab one test batch for visualization
    
            with torch.no_grad():
                x_hat, kld = model(x)
            viz.images(x, nrow=8, win='x', opts=dict(title='x'))
            viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
    
    
    if __name__ == '__main__':
        main()
    

    Start the Visdom server (listening process): python -m visdom.server

    Then open http://localhost:8097 in a browser.

    Reconstruction results when the AE is used:

    Reconstruction results when the VAE is used:

  • Original post: https://www.cnblogs.com/douzujun/p/13632068.html