  • TensorFlow 2.0: VAE in practice

    Comparing the two distributions (here, the encoder's Gaussian is compared against the standard normal prior); the full script is below, with a small KL-divergence check after it:

    import tensorflow as tf
    from tensorflow import keras
    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image
    import os
    from tensorflow.keras import Sequential, layers
    import sys
    
    #   low-level GPU configuration (enable memory growth)
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    
    # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the 2nd GPU
    
    #   tile the first 100 28x28 images into a 10x10 grid and save it as a jpg
    def my_save_img(data,name):
        save_img_path = './img_dir/VAE_img/{}.jpg'.format(name)
        new_img = np.zeros((280,280))
        for index,each_img in enumerate(data[:100]):
            row_start = int(index/10) * 28
            col_start = (index%10)*28
            # print(index,row_start,col_start)
            new_img[row_start:row_start+28,col_start:col_start+28] = each_img
    
        plt.imsave(save_img_path,new_img)
    
    #   hyperparameters
    z_dim = 10
    h_dim = 20
    batchsz = 512
    learn_rate = 1e-3
    
    (x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
    x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
    print('x_train.shape:', x_train.shape)
    train_db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batchsz * 5).batch(batchsz)
    test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)
    
    class VAE(keras.Model):
        def __init__(self):
            super(VAE,self).__init__()
    
            #   Encoder
            self.fc1 = layers.Dense(128)
            self.fc2 = layers.Dense(z_dim)      #   outputs the mean mu
            self.fc3 = layers.Dense(z_dim)      #   outputs the log variance log_var
    
            #   Decoder
            self.fc4 = layers.Dense(128)
            self.fc5 = layers.Dense(784)
    
        def encoder(self,x):
            h = tf.nn.relu(self.fc1(x))
            #   mean of q(z|x)
            mu = self.fc2(h)
            #   log variance of q(z|x); predicting log_var keeps the variance positive
            log_var = self.fc3(h)
    
            return mu,log_var
    
        def decoder(self,z):
    
            out = tf.nn.relu(self.fc4(z))
            out = self.fc5(out)
    
            return out
    
    
        def call(self, inputs, training=None, mask=None):
            #   [b,784] =>[b,z_dim],[b,z_dim]
            mu,log_var = self.encoder(inputs)
    
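            #   reparameterization trick: sample eps ~ N(0, I) and form z = mu + sigma * eps,
            #   so gradients can flow back through mu and log_var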
            eps = tf.random.normal(log_var.shape)
            std = tf.exp(0.5 * log_var)
            z = mu + std * eps
    
            x_hat = self.decoder(z)
            return x_hat,mu,log_var
    
    my_model = VAE()
    # my_model.build(input_shape=(4,784))
    opt = tf.optimizers.Adam(learn_rate)
    
    for epoch in range(1000):
        for step,x in enumerate(train_db):
            x = tf.reshape(x, [-1, 784])
            with tf.GradientTape() as tape:
                x_hat,mu,log_var = my_model(x)
    
                #   reconstruction loss: per-pixel sigmoid cross-entropy on the logits,
                #   summed over the 784 pixels and averaged over the batch
                # rec_loss = tf.losses.binary_crossentropy(x, x_hat, from_logits=True)
                rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_hat)
                rec_loss = tf.reduce_sum(rec_loss)/x.shape[0]
    
                #   KL divergence between N(mu, sigma^2) and the standard normal prior N(0, 1)
                kl_div = -0.5 * (log_var + 1 - mu ** 2 - tf.exp(log_var))
                kl_div = tf.reduce_sum(kl_div) / x.shape[0]
    
                #   total loss: reconstruction term plus KL term (weight 1.0)
                my_loss = rec_loss + 1. * kl_div
    
            grads = tape.gradient(my_loss, my_model.trainable_variables)
            opt.apply_gradients(zip(grads, my_model.trainable_variables))
    
            if step % 100 == 0:
                print(epoch,step,float(my_loss),'kl div:',float(kl_div),'rec loss:',float(rec_loss))
    
        #   evaluation at the end of each epoch
        #   sample random z and generate new images with the decoder alone
        z = tf.random.normal((batchsz,z_dim))
        logits = my_model.decoder(z)
        x_hat = tf.sigmoid(logits)
        x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.
        my_save_img(x_hat,'{}_random'.format(epoch))

        #   reconstruct one batch of test images for comparison with the originals
        x = next(iter(test_db))
        my_save_img(x, '{}_label'.format(epoch))
        x = tf.reshape(x, [-1, 784])
        x_hat_logits, _, _ = my_model(x)
        x_hat = tf.sigmoid(x_hat_logits)
        x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.
        my_save_img(x_hat, '{}_pre'.format(epoch))
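
    The KL term in the loop above is the closed-form KL divergence between the encoder's diagonal Gaussian N(mu, sigma^2) and the standard normal prior N(0, I), i.e. -0.5 * (1 + log_var - mu^2 - exp(log_var)) summed over the latent dimensions. The sketch below is not part of the original script (the helper names kl_closed_form and kl_monte_carlo are made up for illustration); it simply checks that closed form against a Monte Carlo estimate drawn with the same reparameterization:

    import numpy as np
    import tensorflow as tf

    def kl_closed_form(mu, log_var):
        #   analytic KL( N(mu, sigma^2) || N(0, 1) ), summed over the latent dims
        return tf.reduce_sum(-0.5 * (log_var + 1. - mu ** 2 - tf.exp(log_var)), axis=-1)

    def kl_monte_carlo(mu, log_var, n_samples=100000):
        #   estimate the same KL as E_q[log q(z) - log p(z)] with reparameterized samples
        std = tf.exp(0.5 * log_var)
        eps = tf.random.normal((n_samples,) + tuple(mu.shape))
        z = mu + std * eps
        log_q = -0.5 * (np.log(2. * np.pi) + log_var + (z - mu) ** 2 / tf.exp(log_var))
        log_p = -0.5 * (np.log(2. * np.pi) + z ** 2)
        return tf.reduce_mean(tf.reduce_sum(log_q - log_p, axis=-1), axis=0)

    mu = tf.constant([0.5, -1.0, 0.0])
    log_var = tf.constant([0.2, -0.3, 0.0])
    print('closed form :', float(kl_closed_form(mu, log_var)))
    print('monte carlo :', float(kl_monte_carlo(mu, log_var)))

    The two numbers should agree to a couple of decimal places; this per-example quantity is what the training loop sums over the batch and divides by x.shape[0].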
  • Original post: https://www.cnblogs.com/cxhzy/p/14164026.html