  • GAN (Generative Adversarial Networks) - InfoGAN: Principle and Basic Implementation - An Interpretable GAN - 06

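    The original post explains the InfoGAN idea through a series of figures that are not preserved in this text version. As a brief recap (a sketch following the InfoGAN paper, not taken from the missing figures): the generator input is split into unstructured noise z and a latent code c (here a 10-way categorical code for the digit class plus a continuous code), and a mutual-information term is added to the standard GAN value function so the generator cannot ignore c:

        \min_{G,Q} \max_{D} V_{\mathrm{InfoGAN}}(D, G, Q) = V(D, G) - \lambda\, L_I(G, Q), \qquad L_I(G, Q) \le I\big(c;\, G(z, c)\big)

    Here L_I(G, Q) = \mathbb{E}[\log Q(c \mid G(z, c))] + H(c) is a variational lower bound on the mutual information, implemented by an auxiliary head Q that shares layers with the discriminator; in the code below these are the extra Dense(10) and Dense(con_dim) outputs of the discriminator.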

    Code

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np
    import glob
    
    # enable memory growth on the first GPU, if one is present
    gpu = tf.config.experimental.list_physical_devices(device_type='GPU')
    if gpu:
        tf.config.experimental.set_memory_growth(gpu[0], True)
    
    import tensorflow.keras.datasets.mnist as mnist
    
    (train_image, train_label), (_, _) = mnist.load_data()
    

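    The screenshot that followed the cell above is not preserved; it presumably showed the loaded arrays. For reference, a quick check (these are the standard Keras MNIST shapes, not copied from the original output):

    print(train_image.shape, train_image.dtype)   # (60000, 28, 28) uint8
    print(train_label.shape, train_label.dtype)   # (60000,) uint8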

    # scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
    train_image = train_image.astype('float32') / 127.5 - 1
    
    train_image = np.expand_dims(train_image, -1)    # add a channel axis: (60000, 28, 28, 1)
    


    dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))
    


    BATCH_SIZE = 256
    image_count = train_image.shape[0]
    noise_dim = 30    # dimension of the noise vector z
    con_dim = 30      # dimension of the continuous latent code c
    
    dataset = dataset.shuffle(image_count).batch(BATCH_SIZE)
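    # Note: without drop_remainder=True the final batch is smaller than BATCH_SIZE,
    # so the @tf.function-decorated train_step below retraces once for that shape.
    # It is harmless here; .batch(BATCH_SIZE, drop_remainder=True) would avoid it.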
    
    def generator_model():
        noise_seed = layers.Input(shape=(noise_dim,))
        con_seed = layers.Input(shape=(con_dim,))
        label = layers.Input(shape=())    # scalar digit label (0-9)
        
        x = layers.Embedding(10, 30, input_length=1)(label)
        x = layers.Flatten()(x)
        x = layers.concatenate([noise_seed, con_seed, x])
        x = layers.Dense(3*3*128, use_bias=False)(x)
        x = layers.Reshape((3, 3, 128))(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        
        x = layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)     #  7*7
    
        x = layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)    #   14*14
    
        x = layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)
        x = layers.Activation('tanh')(x)
        
        model = tf.keras.Model(inputs=[noise_seed, con_seed, label], outputs=x)  
        
        return model
    
    def discriminator_model():
        image = tf.keras.Input(shape=(28, 28, 1))
        
        x = layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same', use_bias=False)(image)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)
        x = layers.Dropout(0.5)(x)
        
        x = layers.Conv2D(32*2, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)
        x = layers.Dropout(0.5)(x)
        
        x = layers.Conv2D(32*4, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)
        x = layers.Dropout(0.5)(x)
        
        x = layers.Flatten()(x)
        x1 = layers.Dense(1)(x)                                # real/fake logit
        x2 = layers.Dense(10)(x)                               # digit-class logits (categorical code head)
        x3 = layers.Dense(con_dim, activation='sigmoid')(x)   # continuous-code prediction (Q head)
        
        model = tf.keras.Model(inputs=image, outputs=[x1, x2, x3])
        
        return model
    
    generator = generator_model()
    
    discriminator = discriminator_model()
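    # Optional sanity check: the generator maps (noise, code, label) to a 28x28x1
    # image (3x3 -> 7x7 -> 14x14 -> 28x28), and the discriminator returns three
    # heads: a real/fake logit, 10 class logits, and a con_dim-dimensional code.
    print(generator.output_shape)                        # (None, 28, 28, 1)
    print([o.shape.as_list() for o in discriminator.outputs])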
    
    binary_cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    category_cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    
    def discriminator_loss(real_output, real_cat_out, fake_output, label, con_out, cond_in):
        real_loss = binary_cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = binary_cross_entropy(tf.zeros_like(fake_output), fake_output)
        cat_loss = category_cross_entropy(label, real_cat_out)
        con_loss = tf.reduce_mean(tf.square(con_out - cond_in))
        total_loss = real_loss + fake_loss + cat_loss + con_loss
        return total_loss
    
    def generator_loss(fake_output, fake_cat_out, label, con_out, cond_in):
        fake_loss = binary_cross_entropy(tf.ones_like(fake_output), fake_output)
        cat_loss = category_cross_entropy(label, fake_cat_out)
        con_loss = tf.reduce_mean(tf.square(con_out - cond_in))
        return fake_loss + cat_loss + con_loss
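    # Why these terms: the real/fake losses are the standard GAN objective; the
    # categorical cross-entropy pushes the class head to recover the digit label
    # that was fed to the generator; and the squared error between the predicted
    # code (con_out) and the sampled code (cond_in) is the InfoGAN
    # mutual-information term, since under a Gaussian assumption for Q(c|x) the
    # variational lower bound on I(c; G(z, c)) reduces to this MSE up to constants.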
    
    generator_optimizer = tf.keras.optimizers.Adam(1e-5)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-5)
    
    @tf.function
    def train_step(images, labels):
        batchsize = labels.shape[0]
        noise = tf.random.normal([batchsize, noise_dim])
        # sample the continuous latent code c uniformly from [0, 1)
        cond = tf.random.uniform([batchsize, con_dim])
        
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator((noise, cond, labels), training=True)
    
            real_output, real_cat_out, _ = discriminator(images, training=True)
            fake_output, fake_cat_out, con_out = discriminator(generated_images, training=True)
            
            gen_loss = generator_loss(fake_output, fake_cat_out, labels, con_out, cond)
            disc_loss = discriminator_loss(real_output, real_cat_out, fake_output, labels, 
                                           con_out, cond)
    
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    num = 10
    noise_seed = tf.random.normal([num, noise_dim])
    cat_seed = np.random.randint(0, 10, size=(num, 1))
    print(cat_seed.T)
    


    def generate_and_save_images(model, test_noise_input, test_cat_input, epoch):
        print('Epoch:', epoch+1)
        # `training` is set to False so all layers run in inference mode (batchnorm).
        cond_seed = tf.random.uniform([num, con_dim])
        predictions = model((test_noise_input, cond_seed, test_cat_input), training=False)
        predictions = tf.squeeze(predictions)
        fig = plt.figure(figsize=(10, 1))
    
        for i in range(predictions.shape[0]):
            plt.subplot(1, 10, i+1)
            plt.imshow((predictions[i, :, :] + 1)/2, cmap='gray')
            plt.axis('off')
    
    #    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
        plt.show()
    
    def train(dataset, epochs):
        for epoch in range(epochs):
            for image_batch, label_batch in dataset:
                train_step(image_batch, label_batch)
            if epoch%10 == 0:
                generate_and_save_images(generator,
                                         noise_seed,
                                         cat_seed,
                                         epoch)
    
    
        generate_and_save_images(generator,
                                noise_seed,
                                cat_seed,
                                epoch)
    
    EPOCHS = 200
    
    train(dataset, EPOCHS)
    


    generator.save('generate_infogan.h5')
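    # The trained generator can later be restored with:
    #   restored = tf.keras.models.load_model('generate_infogan.h5')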
    
    num = 10
    noise_seed = tf.random.normal([num, noise_dim])
    cat_seed = np.arange(10).reshape(-1, 1)    # one of each digit class, 0-9
    print(cat_seed.T)
    

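    The final screenshots of the original post (not preserved here) showed one generated digit per class, 0 through 9, presumably produced from the seeds defined just above. A minimal sketch of how to reproduce them with the helpers from this post, plus a hypothetical sweep of one continuous-code dimension to see what the code has learned (the sweep is an illustration, not something shown in the original):

    # one generated image per digit class, using the fixed seeds above
    generate_and_save_images(generator, noise_seed, cat_seed, 0)

    # keep noise and label fixed and sweep the first dimension of the continuous
    # code c from 0 to 1 to visualize the factor it controls
    fixed_cond = np.tile(np.random.uniform(0, 1, size=(1, con_dim)), (num, 1)).astype('float32')
    fixed_cond[:, 0] = np.linspace(0.0, 1.0, num)
    preds = generator((noise_seed, fixed_cond, cat_seed), training=False)
    plt.figure(figsize=(10, 1))
    for i in range(num):
        plt.subplot(1, num, i + 1)
        plt.imshow((preds[i, :, :, 0] + 1) / 2, cmap='gray')
        plt.axis('off')
    plt.show()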

  • Original post: https://www.cnblogs.com/gemoumou/p/14186248.html