zoukankan      html  css  js  c++  java
  • 第二十一节,条件变分自编码

    一 条件变分自编码(CVAE)

    变分自编码存在一个问题,虽然可以生成一个样本,但是只能输出与输入图片相同类别的样本。虽然也可以随机从符合模型生成的高斯分布中取数据来还原成样本,但是这样的话饿哦们并不知道生成的样本属于哪个类别。条件变分编码则可以解决这个问题,让网络按指定的类别生成样本。

    在变分自编码的基础上,再取理解条件编码自编码会很容易。主要的改动是,在训练测试时加入一个one-hot向量,用于表示标签向量。其实就是给编码自编码网络加入一个条件,让网络学习图片时加入标签因素,这样就可以按照指定的标签生成图片。 

    二 CVAE实例 

    在编码节点需要在输入端添加标签对应的特征,在解码阶段同样也需要将标签加入输入,这样,再解码的结果向原始的输入样本不断逼近,最终得到的模型会把输入的标签的特征当成MNIST数据的一部分,从而实现通过标签生成指定的图片。

     该程序在上一节程序上作了一些修改,代码如下:

    # -*- coding: utf-8 -*-
    """
    Created on Thu May 31 15:34:08 2018
    
    @author: zy
    """
    
    '''
    条件变分自编码
    '''
    
    
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    from tensorflow.examples.tutorials.mnist import input_data
    
    
    mnist = input_data.read_data_sets('MNIST-data',one_hot=True)
    
    print(type(mnist)) #<class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>
    
    print('Training data shape:',mnist.train.images.shape)           #Training data shape: (55000, 784)
    print('Test data shape:',mnist.test.images.shape)                #Test data shape: (10000, 784)
    print('Validation data shape:',mnist.validation.images.shape)    #Validation data shape: (5000, 784)
    print('Training label shape:',mnist.train.labels.shape)          #Training label shape: (55000, 10)
    
    train_X = mnist.train.images
    train_Y = mnist.train.labels
    test_X = mnist.test.images
    test_Y = mnist.test.labels
    
    
    '''
    定义网络参数
    '''
    n_input = 784
    n_hidden_1 = 256
    n_hidden_2 = 2
    n_classes = 10
    learning_rate = 0.001
    training_epochs = 20               #迭代轮数
    batch_size = 128                   #小批量数量大小
    display_epoch = 3
    show_num = 10
    
    x = tf.placeholder(dtype=tf.float32,shape=[None,n_input])
    y = tf.placeholder(dtype=tf.float32,shape=[None,n_classes])
    #后面通过它输入分布数据,用来生成模拟样本数据
    zinput = tf.placeholder(dtype=tf.float32,shape=[None,n_hidden_2])
    
    
    '''
    定义学习参数
    '''
    weights = {
            'w1':tf.Variable(tf.truncated_normal([n_input,n_hidden_1],stddev = 0.001)),
            'w_lab1':tf.Variable(tf.truncated_normal([n_classes,n_hidden_1],stddev = 0.001)),
            'mean_w1':tf.Variable(tf.truncated_normal([n_hidden_1*2,n_hidden_2],stddev = 0.001)),
            'log_sigma_w1':tf.Variable(tf.truncated_normal([n_hidden_1*2,n_hidden_2],stddev = 0.001)),
            'w2':tf.Variable(tf.truncated_normal([n_hidden_2+n_classes,n_hidden_1],stddev = 0.001)),
            'w3':tf.Variable(tf.truncated_normal([n_hidden_1,n_input],stddev = 0.001))
            }
    
    biases = {
            'b1':tf.Variable(tf.zeros([n_hidden_1])),
            'b_lab1':tf.Variable(tf.zeros([n_hidden_1])),
            'mean_b1':tf.Variable(tf.zeros([n_hidden_2])),
            'log_sigma_b1':tf.Variable(tf.zeros([n_hidden_2])),
            'b2':tf.Variable(tf.zeros([n_hidden_1])),
            'b3':tf.Variable(tf.zeros([n_input]))
            }
    
    
    '''
    定义网络结构
    '''
    #第一个全连接层是由784个维度的输入样->256个维度的输出
    h1 = tf.nn.relu(tf.add(tf.matmul(x,weights['w1']),biases['b1']))
    #输入标签
    h_lab1 = tf.nn.relu(tf.add(tf.matmul(y,weights['w_lab1']),biases['b_lab1']))
    #合并
    hall1 = tf.concat([h1,h_lab1],1)
    
    #第二个全连接层并列了两个输出网络
    z_mean = tf.add(tf.matmul(hall1,weights['mean_w1']),biases['mean_b1'])
    z_log_sigma_sq = tf.add(tf.matmul(hall1,weights['log_sigma_w1']),biases['log_sigma_b1'])
    
    
    #然后将两个输出通过一个公式的计算,输入到以一个2节点为开始的解码部分 高斯分布样本
    eps = tf.random_normal(tf.stack([tf.shape(h1)[0],n_hidden_2]),0,1,dtype=tf.float32)
    z = tf.add(z_mean,tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)),eps))
    #合并
    zall = tf.concat([z,y],1)    #None x 12
    
    
    #解码器 由12个维度的输入->256个维度的输出
    h2 = tf.nn.relu(tf.matmul(zall,weights['w2']) + biases['b2'])
    #解码器 由256个维度的输入->784个维度的输出  即还原成原始输入数据
    reconstruction = tf.matmul(h2,weights['w3']) + biases['b3']
    
    
    #这两个节点不属于训练中的结构,是为了生成指定数据时用的
    zinputall = tf.concat([zinput,y],1)
    h2out = tf.nn.relu(tf.matmul(zinputall,weights['w2']) + biases['b2'])
    reconstructionout = tf.matmul(h2out,weights['w3']) + biases['b3']
    
    '''
    构建模型的反向传播
    '''
    #计算重建loss
    #计算原始数据和重构数据之间的损失,这里除了使用平方差代价函数,也可以使用交叉熵代价函数  
    reconstr_loss = 0.5*tf.reduce_sum((reconstruction-x)**2)
    print(reconstr_loss.shape)    #(,) 标量
    #使用KL离散度的公式
    latent_loss = -0.5*tf.reduce_sum(1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq),1)
    print(latent_loss.shape)      #(128,)
    cost = tf.reduce_mean(reconstr_loss+latent_loss)
    
    
    #定义优化器    
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    
    num_batch = int(np.ceil(mnist.train.num_examples / batch_size))
    
    '''
    开始训练
    '''
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        
        print('开始训练')
        for epoch in range(training_epochs):
            total_cost = 0.0
            for i in range(num_batch):
                batch_x,batch_y = mnist.train.next_batch(batch_size)            
                _,loss = sess.run([optimizer,cost],feed_dict={x:batch_x,y:batch_y})
                total_cost += loss
                
            #打印信息
            if epoch % display_epoch == 0:
                print('Epoch {}/{}  average cost {:.9f}'.format(epoch+1,training_epochs,total_cost/num_batch))
                            
        print('训练完成')
        
        #测试
        print('Result:',cost.eval({x:mnist.test.images,y:mnist.test.labels}))
        #数据可视化   根据原始图片生成自编码数据                  
        reconstruction = sess.run(reconstruction,feed_dict = {x:mnist.test.images[:show_num],y:mnist.test.labels[:show_num]})
        plt.figure(figsize=(1.0*show_num,1*2))        
        for i in range(show_num):
            #原始图像
            plt.subplot(2,show_num,i+1)            
            plt.imshow(np.reshape(mnist.test.images[i],(28,28)),cmap='gray')   
            plt.axis('off')
               
            #变分自编码器重构图像
            plt.subplot(2,show_num,i+show_num+1)
            plt.imshow(np.reshape(reconstruction[i],(28,28)),cmap='gray')       
            plt.axis('off')
        plt.show()
        
    
            
        '''
        高斯分布取样,根据标签生成模拟数据
        '''        
        z_sample = np.random.randn(show_num,2)
        reconstruction = sess.run(reconstructionout,feed_dict={zinput:z_sample,y:mnist.test.labels[:show_num]})    
        plt.figure(figsize=(1.0*show_num,1*2))        
        for i in range(show_num):
            #原始图像
            plt.subplot(2,show_num,i+1)            
            plt.imshow(np.reshape(mnist.test.images[i],(28,28)),cmap='gray')   
            plt.axis('off')
               
            #根据标签成成模拟数据
            plt.subplot(2,show_num,i+show_num+1)
            plt.imshow(np.reshape(reconstruction[i],(28,28)),cmap='gray')       
            plt.axis('off')
        plt.show()
        

    上面第一幅图是根据原始图片生成的自编码数据,第一行为原始数据,第二行为自编码数据,该数据仍然保留一些原始图片的特征。

    第二幅图片是利用样本数据的标签和高斯分布之z_sample一起生成的模拟数据,我们可以看到通过标签生成的数据,已经彻底学会了样本数据的分布,并生成了与输入截然不同但带有相同意义的数据。

  • 相关阅读:
    Mina之session
    GNU C 、ANSI C、标准C、标准c++区别和联系
    SOCKET CLOSE_WAIT 搜集
    [转]二维数组和二级指针的传递问题
    Linux下C语言线程池的实现(1)
    MINA2 之日志配置
    mina里的死锁检测
    MINA2中的拆包组包的处理及一些方法
    void及void指针含义的深刻解析
    JS轻松实现单击文本框弹出选择日期
  • 原文地址:https://www.cnblogs.com/zyly/p/9123443.html
Copyright © 2011-2022 走看看