zoukankan      html  css  js  c++  java
  • VAEs(变分自编码)之keras实践

    VAEs最早由“Diederik P. Kingma and Max Welling, “Auto-Encoding Variational Bayes, arXiv (2013)”和“Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra, “Stochastic Backpropagation and Approximate Inference in Deep Generative Models,” arXiv (2014)”同时发现。

    原理:

    对自编码器来说,它只是将输入数据投影到隐空间中,这些数据在隐空间中的位置是离散的,因此在此空间中进行采样,解码后的输出很可能是毫无意义的。

    而对VAEs来说,它将输入数据转换成2个分布,一个是平均值的分布,一个是方差的分布(这就像高斯混合型了),添加上一些噪音,组合后,再进行解码。

    如图(网上找的,应该是论文里的,暂时没看论文)

     为什么分为2个分布?

    可以这么理解:假设均值和方差都有n个,那么编码部分相当于用n个高斯分布(每个输入是不同权重的n个分布的组合)去模拟输入。

    再通过一系列变换,转化为隐空间的若干维度,其每个维度可能具有某种意义。比如下面代码使用2维隐空间,可以看作是均值和方差维度。

    方差部分指数化,保证非负。添加噪音让隐空间更具有意义的连续性。

    然后我们从隐空间采样,由于隐空间具有意义上的连续性,那么解码后的东东就可能类似输入。

    损失loss如何定义?为什么?

    loss由2部分构成,第一部分就是解码输出与原始输入的loss,可以定义为交叉熵或者均方误差等。

    第二部分是约束项。如上图黄色框,m平方作为L2正则化项,前2项可以看做方差减去其泰勒展开,当σ趋近0时,方差也即e^σ为1。那么最小化前2项必然使得σ趋近0(求导即可知)。

    由此,这第二部分,m平方项约束使得均值为0,前2项约束使得方差为1。这样约束使得隐空间具有连续性,且强制输入数据在隐空间中的表示范围收拢。

    这样在隐空间中2个数据表示的中间,就有一种过渡区域。如果仅以第一部分约束,效果可能就和自编码器一样了,模型会过拟合。


    下面进入代码部分

    以MNIST数据集作为训练样本。

    from keras import backend as K
    
    from keras.models import Model
    
    from keras.metrics import binary_crossentropy
    
    import numpy as np
    
    from keras.layers import Conv2D,Flatten,Dense,Input,Lambda,Reshape,Conv2DTranspose,Layer
    
    from keras.datasets import mnist
    
    from keras.callbacks import EarlyStopping

    编码器使用卷积层,输出2个部分

    img_shape=(28,28,1)
    batch_size=16
    latent_dim=2
    
    input_img=Input(shape=img_shape)
    x=Conv2D(32,3,padding='same',activation='relu')(input_img)# 28,28,32
    x=Conv2D(64,3,padding='same',activation='relu',strides=(2,2))(x)# 14,14,64
    x=Conv2D(64,3,padding='same',activation='relu')(x)#14,14,64
    x=Conv2D(64,3,padding='same',activation='relu')(x)#14,14,64
    # 保存Flatten之前的shape
    shape_before_flattening=K.int_shape(x)
    x=Flatten()(x)#14*14*64
    x=Dense(32,activation='relu')(x)#32
    # 将输入图像拆分为2个向量
    z_mean=Dense(latent_dim)(x)#2
    z_log_var=Dense(latent_dim)(x)

    定义采样方法

    def sampling(args):
        z_mean,z_log_var=args
    #     得到一个平均值为0,方差为1的正态分布,shape为(?,2)
        epsilon=K.random_normal(shape=(K.shape(z_mean)[0],latent_dim),mean=0,stddev=1.)#K.shape返回仍是tensor
    #     tensor*tensor为elementwise操作
        return z_mean+K.exp(z_log_var)*epsilon
    z=Lambda(sampling)([z_mean,z_log_var])# 采样

    解码

    # 解码过程,逆操作
    decode_input=Input(K.int_shape(z)[1:])
    # np.prod表示对数组某个axis进行乘法操作,如果axis不指定,则将所有的元素乘积返回一个值
    x=Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decode_input)#14*14*64
    # 逆Flatten操作
    x=Reshape(shape_before_flattening[1:])(x)#14,14,64
    # 反卷积,strides=2将14*14变为28*28,跟Conv2D相反
    x=Conv2DTranspose(32,3,padding='same',activation='relu',strides=2)(x)#28,28,32
    # 注意这里的激活函数
    x=Conv2D(1,3,padding='same',activation='sigmoid')(x)#28,28,1
    # 解码model
    decoder=Model(decode_input,x)
    # 解码后的图片数据
    z_decoded=decoder(z)

    定义loss,使用一个自定义layer实现

    class CustomVariationalLayer(Layer):
        def vae_loss(self,x,z_decoded):
            x=K.flatten(x)
            z_decoded=K.flatten(z_decoded)
    #         loss为原始输入和编码-解码后的输出比较
            xent_loss=binary_crossentropy(x,z_decoded)
    #         约束
    #         mean部分表示L2正则损失,K.exp(z_log_var)-(1+z_log_var)保证方差为1,如果不约束,网络可能偷懒
            kl_loss=5e-4*K.mean(K.exp(z_log_var)-(1+z_log_var)+K.square(z_mean),axis=-1)
            return K.mean(xent_loss+kl_loss)
    
        def call(self,inputs):
            x=inputs[0]
            z_decoded=inputs[1]
            loss=self.vae_loss(x,z_decoded)
    #         继承方法
            self.add_loss(loss,inputs=inputs)#将根据inputs计算的损失loss加到本layer
            return x #不用,但是需要返回点啥
    
    y=CustomVariationalLayer()([input_img,z_decoded])

    加载数据,定义、训练模型

    (x_train,y_train),(x_test,y_test)=mnist.load_data()
    
    x_train=x_train.astype('float32')/255.
    # 表示添加一个通道维度,通道数为1(颜色只有一种模式)
    x_train=x_train.reshape(x_train.shape+(1,))
    x_test=x_test.astype('float32')/255.
    x_test=x_test.reshape(x_test.shape+(1,))
    vae=Model(input_img,y)
    # 自定义层y里面已经包含了loss,这里不需要指定
    vae.compile(optimizer='rmsprop',loss=None)
    # 不需要标签,所以y为None,我们只需要知道一个图片的原始输入是否和编码-解码后的输出一致
    vae.fit(x=x_train,y=None,shuffle=True,epochs=10,batch_size=batch_size,validation_data=(x_test,None),callbacks=[EarlyStopping(patience=2)],verbose=2)

    测试

    import matplotlib.pyplot as plt
    from scipy.stats import norm
    
    # 潜空间中任意矢量可以解码成数字
    n = 10
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # norm.ppf([v1,v2...])表示正态分布积分值为vi时,对应的x轴坐标值xi
    grid_x = norm.ppf(np.linspace(0.05, 0.95, n))#可以看作均值
    grid_y = norm.ppf(np.linspace(0.05, 0.95, n))#方差
    for i, yi in enumerate(grid_x):
        for j, xi in enumerate(grid_y):
            z_sample = np.array([[xi, yi]])
    #         np.tile将数组重复n次,如[1,2]->[1,2,1,2]。然后reshape到输入格式
            z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
            x_decoded = decoder.predict(z_sample, batch_size=batch_size)
    #         因为x_decoded为16个相同矢量得到的推导,取第一个就行,再将 28*28*1 reshape到 28*28
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit
    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.show()

    结果如下,可以看到,图片是连续变化的。

  • 相关阅读:
    delphi xe5 android sample 中的 SimpleList 是怎样绑定的
    delphi xe5 android 关于文件大小的几个问答O(∩_∩)O~
    笠翁对韵
    delphi xe5 android 控制蓝牙
    delphi xe5 android 使用样式(风格)
    驰骋工作流引擎表单设计器-控件自动完成
    驰骋工作流引擎表单设计器-级联下拉框
    驰骋工作流引擎表单设计器-级联下拉框
    驰骋工作流引擎表单设计器-数据获取
    驰骋工作流引擎表单设计器--表单装载前数据填充
  • 原文地址:https://www.cnblogs.com/lunge-blog/p/11902060.html
Copyright © 2011-2022 走看看