zoukankan      html  css  js  c++  java
  • GAN

    生成式对抗网络(GAN)

    一、什么是生成式对抗网络GAN?

    在知乎上看到一个比较有趣的例子:

    女生让男生给自己拍照,可是一直不满意男生拍的照片,就对照“别人家的男朋友”拍的照片,一次次让男生去改,直到女生满意。

    在这个例子中,

    • 男生可以被看作是GAN中的生成模型(Generative Model);

    • 女生可以被看作是GAN中的判别模型(Discriminator);

    • 整个拍照的过程可以被看作是博弈式的训练过程

    • 男生(生成模型)的目的:拍出女朋友满意的照片(生成一幅和真实图片极其相似的图片)

    • 女生(判别模型)的目的:分辨男朋友拍的照片,不满意的打回去(判别生成图片与真实图片是否相似,如果不够相似,打回去)

    上述博弈过程,如果采用神经网络作为模型类型,则被称为生成式对抗网络(GAN)

    正如视频中提到的两个问题:

    • 为什么罪犯制造的假币越来越逼真?

      为什么GAN可以生成数据?

    二、GAN的详细介绍

    GAN的框架

    判别器D(Discriminator):区分真实样本和虚假样本。D是一个神经网络,经过运算后,如果是真实的图片,给出real(1);如果是假的图片,给出fake(0)

    随机噪声z:从一个先验分布(人为定义,一般是均匀分布或者正态分布)中随机采样的向量

    真实样本x:从数据库中采样的样本

    合成样本G(z):生成模型G输出的样本

    生成器G(Generator):欺骗判别器。生成虚假数据,使得判别器D能够尽可能给出高的评分。生成器不断改变自己,直到生成的很多图片能够欺骗判别器

    GAN目标函数

    训练算法:

    1.随机初始化生成器和判别器

    2.交替训练判别器D和生成器G,直到收敛

    • 步骤一:固定生成器G(不优化),训练判别器D区分真实图像与合成图像(赋予真实图像高分,赋予合成图像低分)(用监督训练二分类问题)
    • 步骤二:固定判别器D,训练生成器G欺骗判别器D(更新生成器的参数,使其合成的图片被生成器D赋予高分)(最大化问题)

    训练一个生成模型

    一个能够生成我们想要的数据的模型(图模型、函数、神经网络)

    GAN通过一个低维向量 生成器(全连接神经网络)

    cGAN生成可控的数据 生成器(全连接神经网络)

    DCGAN 生成器(卷积神经网络)

    WGAN 生成器(WGAN)重新设计目标函数,训练更稳定,生成数据质量更棒

    KL散度和JS散度

    • KL散度(Kullback-Leibler divergence)

      一种衡量两个概率分布的匹配程度的指标,又称为KL距离,相对熵

    当P(x)和Q(x)的相似度越高,KL散度越小

    KL散度主要有两个性质:

    (1)不对称性

    (2)非负性

    KL散度本质是用来衡量两个概率分布的差异一种数学计算方式;由于用到比值除法不具备对称性。

    神经网络训练时为何不用KL散度,从数学上来讲,KL散度多减了一个H(P);P代表真实分布,Q代表估计的分布

    极大似然估计等价于最小化生成数据分布和真实分布的KL散度

    • JS散度(Jensen-Shannon divergence)

      JS散度也称为JS距离,是KL散度的一种变形

    JS散度主要性质:

    (1)值域范围(JS散度的值域范围是[0,1],相同是0,相反为1)

    (2)对称性

    (3)交叉熵

    很多情况下,假设数据符合高斯分布是不合理的,数据分布是无法用公式显示的写出来的

    因此用高斯模型去拟合数据分布,我们需要一个更通用的生成模型,可以拟合任意数据分布,如下

    GAN:生成式对抗网络通过对抗训练,间接计算出散度JS,使得模型可以优化

    GAN做的事情:

    1.最大化判别器损失,等价于计算合成数据分布和真实数据分布的JS散度

    2.最小化生成器损失,等价于最小化JS散度(也就是优化生成模型 )

    三、DCGAN

    四、代码练习

    (一)GAN

    1. 通过make_moons生成双半月形的数据,同时把数据点画出来

    1. 定义生成器、判别器、优化器

      判别器中使用了sigmoid函数(可能是因为需要判别生成的图片是否是真实图片,即相当于是一个二分类的问题,因此用sigmoid函数)

      优化器选择的是adam

    2. 对抗训练

      整个对抗训练可以分为两部分:

      • 第一部分(固定生成器G,改进判别器D)
      • 第二部分(固定判别器D,改进生成器G)
    3. 修改learning_rate和batch_size

      学习率为0.0001,batch_size为50的结果:

    学习率为0.001,batch_size为250的结果:

    可以明显看出随着batch_size的增大、loss的减小,效果明显改善。

    (个人猜测:增大batch_size的值后,能够一次性处理更多的数据,从而能够更好地把握大方向,训练的波动程度更小)

    (二)CGAN(条件生成-对抗网络)

    • 对比于GAN,CGAN在生成器以及判别器上都多了一个标签作为输入
    • 生成器的输入是噪声和标签,输出是生成图
    • 判别器的输入是生成图,真实图以及标签,输出是真和假

    步骤与GAN相似,不同的是在生成器和判别器的定义中加入了10维的标签信息

    全连接判别器:

    全连接生成器:

    epoch改为100后:

    在epoch为100时,辨别器的损失为0.00030,效果不太好

    (三)DCGAN(深度卷积对抗网络)

    • 对比于GAN,在判别器和生成器中使用了卷积结构(在第二个、第三个、第四个滑动卷积层中使用BN加快网络收敛),同样添加Sigmoid激活函数

    滑动卷积判别器:

    反滑动卷积生成器:

    • 第一层:把输入线性变换成256×4×4的矩阵,并在这个基础上做反卷积
    • 第四层:不使用BN,使用tanh激活函数

    epoch为30时,结果如下:

    在epoch改为100后,效果不如epoch为30的结果(不想明白什么原因)

  • 相关阅读:
    例行性工作排程 (crontab)
    数组
    继续我们的学习。这次鸟哥讲的是LVM。。。磁盘管理 最后链接文章没有看
    htop资源管理器
    转:SSL协议详解
    转:SSL 握手协议详解
    转:Connection reset原因分析和解决方案
    使用Mybatis-Generator自动生成Dao、Model、Mapping相关文件(转)
    转:logback的使用和logback.xml详解
    转:Java logger组件:slf4j, jcl, jul, log4j, logback, log4j2
  • 原文地址:https://www.cnblogs.com/cch-EX/p/13657957.html
Copyright © 2011-2022 走看看