zoukankan      html  css  js  c++  java
  • 理解GAN对抗神经网络的损失函数和训练过程

    GAN最不好理解的就是Loss函数的定义和训练过程,这里用一段代码来辅助理解,就能明白到底是怎么回事。其实GAN的损失函数并没有特殊之处,就是常用的binary_crossentropy,关键在于训练过程中存在两个神经网络和两个损失函数。

    np.random.seed(42)
    tf.random.set_seed(42)
    
    codings_size = 30
    
    generator = keras.models.Sequential([
        keras.layers.Dense(100, activation="selu", input_shape=[codings_size]),
        keras.layers.Dense(150, activation="selu"),
        keras.layers.Dense(28 * 28, activation="sigmoid"),
        keras.layers.Reshape([28, 28])
    ])
    discriminator = keras.models.Sequential([
        keras.layers.Flatten(input_shape=[28, 28]),
        keras.layers.Dense(150, activation="selu"),
        keras.layers.Dense(100, activation="selu"),
        keras.layers.Dense(1, activation="sigmoid")
    ])
    gan = keras.models.Sequential([generator, discriminator])
    
    discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
    discriminator.trainable = False
    gan.compile(loss="binary_crossentropy", optimizer="rmsprop")
    
    batch_size = 32
    dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)
    

    这里generator并不用compile,因为gan网络已经compile了。具体原因见下文。

    训练过程的代码如下

    def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):
        generator, discriminator = gan.layers
        for epoch in range(n_epochs):
            print("Epoch {}/{}".format(epoch + 1, n_epochs))              # not shown in the book
            for X_batch in dataset:
                # phase 1 - training the discriminator
                noise = tf.random.normal(shape=[batch_size, codings_size])
                generated_images = generator(noise)
                X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
                y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
                discriminator.trainable = True
                discriminator.train_on_batch(X_fake_and_real, y1)
                # phase 2 - training the generator
                noise = tf.random.normal(shape=[batch_size, codings_size])
                y2 = tf.constant([[1.]] * batch_size)
                discriminator.trainable = False
                gan.train_on_batch(noise, y2)
            plot_multiple_images(generated_images, 8)                     # not shown
            plt.show()                                                    # not shown
    

    第一阶段(discriminator训练)

    # phase 1 - training the discriminator
    noise = tf.random.normal(shape=[batch_size, codings_size])
    generated_images = generator(noise)
    X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
    y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
    discriminator.trainable = True
    discriminator.train_on_batch(X_fake_and_real, y1)
    

    这个阶段首先生成数量相同的真实图片和假图片,concat在一起,即X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)。然后是label,真图片的label是1,假图片的label是0。

    然后是迅速阶段,首先将discrinimator设置为可训练,discriminator.trainable = True,然后开始阶段。第一个阶段的训练过程只训练discriminator,discriminator.train_on_batch(X_fake_and_real, y1),而不是整个GAN网络gan

    第二阶段(generator训练)

    # phase 2 - training the generator
    noise = tf.random.normal(shape=[batch_size, codings_size])
    y2 = tf.constant([[1.]] * batch_size)
    discriminator.trainable = False
    gan.train_on_batch(noise, y2)
    

    在第二阶段首先生成假图片,但是不再生成真图片。把假图片的label全部设置为1,并把discriminator的权重冻结,即discriminator.trainable = False。这一步很关键,应该这么理解:

    前面第一阶段的是discriminator的训练,使真图片的预测值尽量接近1,假图片的预测值尽量接近0,以此来达到优化损失函数的目的。现在将discrinimator的权重冻结,网络中输入假图片,并故意把label设置为1。

    注意,在整个gan网络中,从上向下的顺序是先通过geneartor,再通过discriminator,即gan = keras.models.Sequential([generator, discriminator])。第二个阶段将discrinimator冻结,并训练网络gan.train_on_batch(noise, y2)。如果generator生成的图片足够真实,经过discrinimator后label会尽可能接近1。由于故意把y2的label设置为1,所以如果genrator生成的图片足够真实,此时generator训练已经达到最优状态,不会大幅度更新权重;如果genrator生成的图片不够真实,经过discriminator之后,预测值会接近0,由于y2的label是1,相当于预测值不准确,这时候gan网络的损失函数较大,generator会通过更新generator的权重来降低损失函数。

    之后,重新回到第一阶段训练discriminator,然后第二阶段训练generator。假设整个GAN网络达到理想状态,这时候generator产生的假图片,经过discriminator之后,预测值应该是0.5。假如这个值小于0.5,证明generator不是特别准确,在第二阶段训练过程中,generator的权重会被继续更新。假如这个值大于0.5,证明discriminator不是特别准确,在第一阶段训练中,discriminator的权征会被继续更新。

    简单说,对于一张generator生成的假图片,discriminator会尽量把预测值拉下拉,generator会尽量把预测值往上扯,类似一个拔河的过程,最后达到均衡状态,例如0.6, 0.4, 0.55, 0.45, 0.51, 0.49, 0.50。

  • 相关阅读:
    网络流量监控工具iftop
    CentOS6.X安装vsftpd服务
    CentOS 6.x版本升级Mysql
    CentOS 5.x版本升级Mysql
    CentOS 5.x版本升级PHP
    CentOS 6.X版本升级PHP
    Spring bean configuration inheritance
    cannot load such file -- openssl
    第八章、Linux 磁盘与文件系统管理
    Laravel Configuration
  • 原文地址:https://www.cnblogs.com/yaos/p/14014111.html
Copyright © 2011-2022 走看看