zoukankan      html  css  js  c++  java
  • GAN 生成对抗网络 - InfoGAN 原理与基本实现 - 可解释的生成对抗网络 - 06

    （原文此处为若干配图，图片未保留）

    代码

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np
    import glob
    
    # Let TensorFlow allocate GPU memory on demand instead of reserving it all
    # up front. Iterating (rather than indexing gpu[0]) keeps the notebook
    # runnable on CPU-only machines, where the device list is empty.
    gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
    for device in gpu:
        tf.config.experimental.set_memory_growth(device, True)
    
    import tensorflow.keras.datasets.mnist as mnist
    
    # Only the training split is used; the test split is discarded.
    (train_image, train_label), (_, _) = mnist.load_data()
    

    （原文此处为配图，图片未保留）

    # Scale pixel values by 1/127.5 and shift by -1, i.e. [0, 255] -> [-1, 1],
    # matching the generator's tanh output range. Cast to float32 (Keras'
    # default float dtype) instead of keeping NumPy's float64, which halves
    # the memory of this array and of the copy made by from_tensor_slices.
    train_image = (train_image / 127.5 - 1).astype(np.float32)
    
    # Add a trailing channel axis so images match the discriminator's
    # (28, 28, 1) input.
    train_image = np.expand_dims(train_image, -1)
    

    （原文此处为配图，图片未保留）

    # One dataset element per training example: an (image, label) pair.
    dataset = tf.data.Dataset.from_tensor_slices(tensors=(train_image, train_label))
    

    （原文此处为配图，图片未保留）

    # Training hyper-parameters.
    BATCH_SIZE = 256
    image_count = len(train_image)  # number of training images
    noise_dim = 30  # width of the generator's noise input
    con_dim = 30    # width of the continuous latent code
    
    # Shuffle with a buffer as large as the whole training set, then batch.
    dataset = dataset.shuffle(buffer_size=image_count).batch(batch_size=BATCH_SIZE)
    
    def generator_model():
        """Build the InfoGAN generator as a functional Keras model.
    
        Inputs, in order:
            * noise vector of length ``noise_dim``;
            * continuous latent code of length ``con_dim``;
            * class label, one scalar per sample (embedded over 10 classes).
    
        Output: a 28x28x1 image with values in [-1, 1] (tanh activation).
        """
        def upsample(tensor, filters, **conv_kwargs):
            # Stride-2 transposed convolution (no bias) -> BatchNorm -> ReLU.
            tensor = layers.Conv2DTranspose(filters, (3, 3), strides=(2, 2),
                                            use_bias=False, **conv_kwargs)(tensor)
            tensor = layers.BatchNormalization()(tensor)
            return layers.ReLU()(tensor)
    
        noise_in = layers.Input(shape=(noise_dim,))
        code_in = layers.Input(shape=(con_dim,))
        label_in = layers.Input(shape=())
    
        # Map the label to a learned 30-dimensional vector.
        label_vec = layers.Embedding(10, 30, input_length=1)(label_in)
        label_vec = layers.Flatten()(label_vec)
    
        # Project the joint latent vector onto a 3x3x128 feature map.
        h = layers.concatenate([noise_in, code_in, label_vec])
        h = layers.Dense(3 * 3 * 128, use_bias=False)(h)
        h = layers.Reshape((3, 3, 128))(h)
        h = layers.BatchNormalization()(h)
        h = layers.ReLU()(h)
    
        h = upsample(h, 64)                  # default 'valid' padding: 3x3 -> 7x7
        h = upsample(h, 32, padding='same')  # 7x7 -> 14x14
    
        # Last upsampling to 28x28 with one channel, squashed by tanh.
        h = layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', use_bias=False)(h)
        h = layers.Activation('tanh')(h)
    
        return tf.keras.Model(inputs=[noise_in, code_in, label_in], outputs=h)
    
    def discriminator_model():
        """Build the shared discriminator / auxiliary network.
    
        Takes a 28x28x1 image and returns three heads:
            * a single real/fake logit (no activation);
            * 10 class logits (no activation);
            * ``con_dim`` sigmoid outputs, the predicted continuous code.
        """
        image = tf.keras.Input(shape=(28, 28, 1))
    
        # Three stride-2 conv stages with doubling width, each followed by
        # BatchNorm, LeakyReLU and 50% dropout.
        features = image
        for filters in (32, 32 * 2, 32 * 4):
            features = layers.Conv2D(filters, (3, 3), strides=(2, 2), padding='same',
                                     use_bias=False)(features)
            features = layers.BatchNormalization()(features)
            features = layers.LeakyReLU()(features)
            features = layers.Dropout(0.5)(features)
    
        features = layers.Flatten()(features)
        validity_logit = layers.Dense(1)(features)
        class_logits = layers.Dense(10)(features)
        code_pred = layers.Dense(con_dim, activation='sigmoid')(features)
    
        return tf.keras.Model(inputs=image, outputs=[validity_logit, class_logits, code_pred])
    
    # Instantiate both networks (the generator is built first).
    generator, discriminator = generator_model(), discriminator_model()
    
    # The discriminator's real/fake and class heads output raw logits,
    # so both losses are configured with from_logits=True.
    binary_cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    category_cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    
    def discriminator_loss(real_output, real_cat_out, fake_output, label, con_out, cond_in):
        """Sum of the discriminator-side losses.
    
        Args:
            real_output: real/fake logits for real images (target: 1).
            real_cat_out: class logits for real images.
            fake_output: real/fake logits for generated images (target: 0).
            label: integer class labels for the real images.
            con_out: continuous code predicted from the generated images.
            cond_in: continuous code that was fed to the generator.
    
        Returns:
            adversarial + classification + mean-squared code-reconstruction loss.
        """
        adversarial = (binary_cross_entropy(tf.ones_like(real_output), real_output)
                       + binary_cross_entropy(tf.zeros_like(fake_output), fake_output))
        classification = category_cross_entropy(label, real_cat_out)
        code_mse = tf.reduce_mean(tf.square(con_out - cond_in))
        return adversarial + classification + code_mse
    
    def generator_loss(fake_output, fake_cat_out, label, con_out, cond_in):
        """Sum of the generator-side losses.
    
        Args:
            fake_output: real/fake logits for generated images (target: 1,
                i.e. the generator tries to make them look real).
            fake_cat_out: class logits for generated images.
            label: class labels the generated images were conditioned on.
            con_out: continuous code predicted from the generated images.
            cond_in: continuous code that was fed to the generator.
        """
        fooling = binary_cross_entropy(tf.ones_like(fake_output), fake_output)
        classification = category_cross_entropy(label, fake_cat_out)
        code_mse = tf.reduce_mean(tf.square(con_out - cond_in))
        return fooling + classification + code_mse
    
    # One independent Adam optimizer per network, both at learning rate 1e-5.
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
    
    @tf.function
    def train_step(images, labels):
        """Run one simultaneous generator/discriminator update on a batch.
    
        Args:
            images: batch of real images for the discriminator.
            labels: class labels of ``images``; the generated batch is
                conditioned on the same labels.
        """
        # Dynamic batch size: the final batch of an epoch can be smaller than
        # BATCH_SIZE, and tf.shape works even when the static size is unknown.
        batchsize = tf.shape(labels)[0]
        noise = tf.random.normal([batchsize, noise_dim])
        # Continuous latent code: its width must match the generator's
        # con_seed input (con_dim). It is sampled in [0, 1) to match the
        # discriminator's sigmoid code head.
        cond = tf.random.uniform([batchsize, con_dim])
    
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator((noise, cond, labels), training=True)
    
            real_output, real_cat_out, _ = discriminator(images, training=True)
            fake_output, fake_cat_out, con_out = discriminator(generated_images, training=True)
    
            gen_loss = generator_loss(fake_output, fake_cat_out, labels, con_out, cond)
            disc_loss = discriminator_loss(real_output, real_cat_out, fake_output, labels,
                                           con_out, cond)
    
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    # Fixed inputs for monitoring progress during training: `num` noise
    # vectors and `num` random class labels in [0, 10), shaped (num, 1).
    num = 10
    noise_seed = tf.random.normal(shape=[num, noise_dim])
    cat_seed = np.random.randint(0, 10, size=(num, 1))
    print(np.transpose(cat_seed))
    

    （原文此处为配图，图片未保留）

    def generate_and_save_images(model, test_noise_input, test_cat_input, epoch, save=False):
        """Show one row of generator samples for visual inspection.
    
        A fresh continuous code is drawn on every call, so repeated calls with
        the same noise/label inputs can still produce different images.
    
        Args:
            model: generator taking (noise, continuous code, label).
            test_noise_input: noise batch, one row per image to draw.
            test_cat_input: class labels, one per row of ``test_noise_input``.
            epoch: zero-based epoch index; printed as ``epoch + 1``.
            save: if True, also write the figure to
                ``image_at_epoch_{epoch:04d}.png`` before showing it.
        """
        print('Epoch:', epoch+1)
        # Size the code batch from the actual input instead of a global.
        count = int(test_noise_input.shape[0])
        cond_seed = tf.random.uniform([count, con_dim])
        # training=False runs BatchNormalization in inference mode.
        predictions = model((test_noise_input, cond_seed, test_cat_input), training=False)
        # Drop only the channel axis; a bare squeeze would also drop the batch
        # axis when count == 1 and break the indexing below.
        predictions = tf.squeeze(predictions, axis=-1)
    
        plt.figure(figsize=(count, 1))
        for i in range(count):
            plt.subplot(1, count, i + 1)
            # Map the tanh range [-1, 1] back to [0, 1] for display.
            plt.imshow((predictions[i] + 1) / 2, cmap='gray')
            plt.axis('off')
    
        if save:
            plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
        plt.show()
    
    def train(dataset, epochs, display_every=10):
        """Train the InfoGAN for ``epochs`` full passes over ``dataset``.
    
        Sample images are shown after epoch 0 and after every
        ``display_every``-th epoch, and once more after the final epoch unless
        that epoch was already shown.
    
        Args:
            dataset: iterable of (image_batch, label_batch) pairs.
            epochs: number of passes over ``dataset``; 0 or less trains nothing.
            display_every: interval, in epochs, between sample displays.
    
        Raises:
            ValueError: if ``display_every`` is not positive.
        """
        if display_every <= 0:
            raise ValueError('display_every must be positive')
    
        for epoch in range(epochs):
            for image_batch, label_batch in dataset:
                train_step(image_batch, label_batch)
            if epoch % display_every == 0:
                generate_and_save_images(generator, noise_seed, cat_seed, epoch)
    
        # Final snapshot. Skipped when no epoch ran (there is nothing to show)
        # or when the last epoch was already displayed inside the loop.
        final_epoch = epochs - 1
        if final_epoch >= 0 and final_epoch % display_every != 0:
            generate_and_save_images(generator, noise_seed, cat_seed, final_epoch)
    
    # Number of full passes over the training set.
    EPOCHS = 200
    
    train(dataset, epochs=EPOCHS)
    

    （原文此处为若干配图，图片未保留）

    # Persist the trained generator (.h5 file name).
    generator.save('generate_infogan.h5')
    
    # New sampling inputs: `num` noise vectors and the labels 0..9 as a
    # (10, 1) column, one label per generated image.
    num = 10
    noise_seed = tf.random.normal(shape=[num, noise_dim])
    cat_seed = np.arange(10)[:, np.newaxis]
    print(cat_seed.T)
    

    （原文此处为若干配图，图片未保留）

  • 相关阅读:
    新版淘淘商城_01_简介
    JavaMail之-通过邮件激活账号
    javaMail发送邮件
    JavaMail学习之一-邮件传输协议
    解决ios的safari不能自动播放audio问题(以及部分微信也不能自动播放)
    css3背景渐变色
    jq杂记
    各种“分享按钮“方法总结
    底部导航统一高度
    js 与或运算符 || && 妙用
  • 原文地址:https://www.cnblogs.com/gemoumou/p/14186248.html
Copyright © 2011-2022 走看看