zoukankan      html  css  js  c++  java
  • 生成对抗网络GAN详解与代码

    1.GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:

    • G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。

    • D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

    在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”

    最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。

    这样我们的目的就达成了:我们得到了一个生成式的模型G,它可以用来生成图片。

    以上只是大致说了一下GAN的核心原理,如何用数学语言描述呢?这里直接摘录论文里的公式:

    (1)优化D:

    优化第一项是真是样本x输入的时候,结果越大越好;对于噪声等的输入z,生成的假样本G(z)要越小越好

    (2)优化G:

    优化生成器时和真是样本没关系,故不需要考虑;这时候只有假样本,但生成器希望假样本越逼真越好(接近1),故D(G(z)越大越好,则最小化1-D(G(z))

     2.GAN的特点:

        (1)相比较传统的模型,他存在两个不同的网络,而不是单一的网络,并且训练方式采用的是对抗训练方式

        (2)GAN中G的梯度更新信息来自判别器D,而不是来自数据样本

    3. GAN 的优点:

        (1) GAN是一种生成式模型,相比较其他生成模型(玻尔兹曼机和GSNs)只用到了反向传播,而不需要复杂的马尔科夫链

        (2)相比其他所有模型, GAN可以产生更加清晰,真实的样本

        (3)GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域

        (4)相比于变分自编码器, GANs没有引入任何决定性偏置( deterministic bias),变分方法引入决定性偏置,因为他们优化对数似然的下界,而不是似然度本身,这看起来导致了VAEs生成的实例比GANs更模糊

        (5)相比VAE, GANs没有变分下界,如果鉴别器训练良好,那么生成器可以完美的学习到训练样本的分布.换句话说,GANs是渐进一致的,但是VAE是有偏差的

        (6)GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,不管三七二十一,只要有一个的基准,直接上判别器,剩下的就交给对抗训练了。

     4. GAN的缺点:

        (1)训练GAN需要达到纳什均衡,有时候可以用梯度下降法做到,有时候做不到.我们还没有找到很好的达到纳什均衡的方法,所以训练GAN相比VAE或者PixelRNN是不稳定的,但我认为在实践中它还是比训练玻尔兹曼机稳定的多

        (2)GAN不适合处理离散形式的数据,比如文本

        (3)GAN存在训练不稳定、梯度消失、模式崩溃的问题(目前已解决)

    5.为什么GAN中的优化器不常用SGD

        (1)SGD容易震荡,容易使GAN训练不稳定,

        (2)GAN的目的是在高维非凸的参数空间中找到纳什均衡点,GAN的纳什均衡点是一个鞍点,但是SGD只会找到局部极小值,因为SGD解决的是一个寻找最小值的问题,GAN是一个博弈问题。

    6.训练GAN的一些技巧

    (1). 输入规范化到(-1,1)之间,最后一层的激活函数使用tanh(BEGAN除外)

    (2). 使用wassertein GAN的损失函数

    (3). 如果有标签数据的话,尽量使用标签,也有人提出使用反转标签效果很好,另外使用标签平滑,单边标签平滑或者双边标签平滑

    (4). 使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm

    (5). 避免使用RELU和pooling层,减少稀疏梯度的可能性,可以使用leakrelu激活函数

    (6). 优化器尽量选择ADAM,学习率不要设置太大,初始1e-4可以参考,另外可以随着训练进行不断缩小学习率,

    (7). 给D的网络层增加高斯噪声,相当于是一种正则

    7.GAN实战

    import tensorflow as tf #导入tensorflow
    from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
    import numpy as np #导入numpy
    import matplotlib.pyplot as plt #plt是绘图工具,在训练过程中用于输出可视化结果
    import matplotlib.gridspec as gridspec #gridspec是图片排列工具,在训练过程中用于输出可视化结果
    import os #导入os
     
        
    def xavier_init(size): #初始化参数时使用的xavier_init函数
        in_dim = size[0] 
        xavier_stddev = 1. / tf.sqrt(in_dim / 2.) #初始化标准差
        return tf.random_normal(shape=size, stddev=xavier_stddev) #返回初始化的结果
    
    X = tf.placeholder(tf.float32, shape=[None, 784]) #X表示真的样本(即真实的手写数字)
    
    D_W1 = tf.Variable(xavier_init([784, 128])) #表示使用xavier方式初始化的判别器的D_W1参数,是一个784行128列的矩阵
    D_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的判别器的D_1参数,是一个长度为128的向量 
    D_W2 = tf.Variable(xavier_init([128, 1])) #表示使用xavier方式初始化的判别器的D_W2参数,是一个128行1列的矩阵
    D_b2 = tf.Variable(tf.zeros(shape=[1])) ##表示全零方式初始化的判别器的D_1参数,是一个长度为1的向量
    theta_D = [D_W1, D_W2, D_b1, D_b2] #theta_D表示判别器的可训练参数集合
    
    Z = tf.placeholder(tf.float32, shape=[None, 100]) #Z表示生成器的输入(在这里是噪声),是一个N列100行的矩阵
     
    G_W1 = tf.Variable(xavier_init([100, 128])) #表示使用xavier方式初始化的生成器的G_W1参数,是一个100行128列的矩阵
    G_b1 = tf.Variable(tf.zeros(shape=[128])) #表示全零方式初始化的生成器的G_b1参数,是一个长度为128的向量 
    G_W2 = tf.Variable(xavier_init([128, 784])) #表示使用xavier方式初始化的生成器的G_W2参数,是一个128行784列的矩阵
    G_b2 = tf.Variable(tf.zeros(shape=[784])) #表示全零方式初始化的生成器的G_b2参数,是一个长度为784的向量
    theta_G = [G_W1, G_W2, G_b1, G_b2] #theta_G表示生成器的可训练参数集合
    
    def sample_Z(m, n): #生成维度为[m, n]的随机噪声作为生成器G的输入
        return np.random.uniform(-1., 1., size=[m, n])
    
    def generator(z): #生成器,z的维度为[N, 100]
        G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #输入的随机噪声乘以G_W1矩阵加上偏置G_b1,G_h1维度为[N, 128]
        G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G_h1乘以G_W2矩阵加上偏置G_b2,G_log_prob维度为[N, 784]
        G_prob = tf.nn.sigmoid(G_log_prob) #G_log_prob经过一个sigmoid函数,G_prob维度为[N, 784] 
        return G_prob #返回G_prob
    
    def discriminator(x): #判别器,x的维度为[N, 784]
        D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) #输入乘以D_W1矩阵加上偏置D_b1,D_h1维度为[N, 128]
        D_logit = tf.matmul(D_h1, D_W2) + D_b2 #D_h1乘以D_W2矩阵加上偏置D_b2,D_logit维度为[N, 1]
        D_prob = tf.nn.sigmoid(D_logit) #D_logit经过一个sigmoid函数,D_prob维度为[N, 1]
        return D_prob, D_logit #返回D_prob, D_logit
    
    G_sample = generator(Z) #取得生成器的生成结果
    D_real, D_logit_real = discriminator(X) #取得判别器判别的真实手写数字的结果
    D_fake, D_logit_fake = discriminator(G_sample) #取得判别器判别的生成的手写数字的结果
    
    #对判别器对真实样本的判别结果计算误差(将结果与1比较)
    D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, targets=tf.ones_like(D_logit_real))) 
    #对判别器对虚假样本(即生成器生成的手写数字)的判别结果计算误差(将结果与0比较)
    D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.zeros_like(D_logit_fake))) 
    #判别器的误差
    D_loss = D_loss_real + D_loss_fake 
    #生成器的误差(将判别器返回的对虚假样本的判别结果与1比较)
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.ones_like(D_logit_fake))) 
    
    mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集
    
    D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) #判别器的训练器
    G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #生成器的训练器
    
    mb_size = 128 #训练的batch_size
    Z_dim = 100 #生成器输入的随机噪声的列的维度
      
    sess = tf.Session() #会话层
    sess.run(tf.initialize_all_variables()) #初始化所有可训练参数
    
    def plot(samples): #保存图片时使用的plot函数
        fig = plt.figure(figsize=(4, 4)) #初始化一个4行4列包含16张子图像的图片
        gs = gridspec.GridSpec(4, 4) #调整子图的位置
        gs.update(wspace=0.05, hspace=0.05) #置子图间的间距
        for i, sample in enumerate(samples): #依次将16张子图填充进需要保存的图像
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r') 
        return fig
    
    
    path = '/data/User/zcc/' #保存可视化结果的路径
    i = 0 #训练过程中保存的可视化结果的索引 
    for it in range(1000000): #训练100万次
        if it % 1000 == 0: #每训练1000次就保存一下结果
            samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})
            fig = plot(samples) #通过plot函数生成可视化结果
            plt.savefig(path+'out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight') #保存可视化结果
            i += 1
            plt.close(fig)
     
        X_mb, _ = mnist.train.next_batch(mb_size) #得到训练一个batch所需的真实手写数字(作为判别器的输入)
     
        #下面是得到训练一次的结果,通过sess来run出来
        _, D_loss_curr, D_loss_real, D_loss_fake, D_loss = sess.run([D_solver, D_loss, D_loss_real, D_loss_fake, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
        _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
     
        if it % 1000 == 0: #每训练1000次输出一下结果
            print('Iter: {}'.format(it))
            print('D loss: {:.4}'. format(D_loss_curr))
            print('G_loss: {:.4}'.format(G_loss_curr))
            print()

     参考博客:

    https://blog.csdn.net/m0_37407756/article/details/75309670

    https://blog.csdn.net/jiongnima/article/details/80033169

  • 相关阅读:
    Spring 事务传播实践分析
    记一次%转义引发的血案
    Springboot+redis 整合
    SpringBoot基础梳理
    MyBatis String类型传递参数注意事项
    SpringBoot填坑系列---XML方式配置数据库
    自定义AlertView(Swift)
    iOS开发,最新判断是否是手机号的正则表达式
    iOS开发 UILabel实现自适应高宽
    iOS开发笔记--UILabel的相关属性设置
  • 原文地址:https://www.cnblogs.com/USTC-ZCC/p/11236847.html
Copyright © 2011-2022 走看看