zoukankan      html  css  js  c++  java
  • GAN 生成对抗网络 - InfoGAN 原理与基本实现 - 可解释的生成对抗网络 - 06

    [此处原有 18 张插图，转载时图片未能保留，仅剩占位文字“在这里插入图片描述”。]

    代码

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np
    import glob
    
    # Let TensorFlow grow GPU memory on demand instead of reserving it all up
    # front. Looping over the device list (instead of indexing [0]) keeps the
    # notebook runnable on CPU-only machines, where the list is empty, and
    # applies the same setting to every GPU, since TensorFlow refuses differing
    # memory-growth settings across GPUs.
    for physical_gpu in tf.config.experimental.list_physical_devices(device_type='GPU'):
        tf.config.experimental.set_memory_growth(physical_gpu, True)

    import tensorflow.keras.datasets.mnist as mnist

    # Only the training split is used; the test split is discarded.
    (train_image, train_label), (_, _) = mnist.load_data()
    

    [原文插图未保留]

    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
    # Cast to float32 first: dividing the uint8 array by a Python float would
    # otherwise produce float64, doubling memory (and the copy made by
    # from_tensor_slices) for no benefit, since Keras layers compute in float32
    # by default.
    train_image = train_image.astype('float32') / 127.5 - 1

    # Add a trailing channel axis so images match the discriminator's (28, 28, 1) input.
    train_image = np.expand_dims(train_image, -1)
    

    [原文插图未保留]

    # Pair every image with its digit label so each element is (image, label).
    _image_label_pairs = (train_image, train_label)
    dataset = tf.data.Dataset.from_tensor_slices(_image_label_pairs)
    

    [原文插图未保留]

    # Hyper-parameters shared by the model builders and the training step.
    BATCH_SIZE = 256
    noise_dim = 30  # length of the Gaussian noise input to the generator
    con_dim = 30    # length of the continuous latent code (sampled uniformly during training)

    # Shuffle over the full training set, then group into batches.
    image_count = train_image.shape[0]
    dataset = dataset.shuffle(image_count).batch(BATCH_SIZE)
    
    def generator_model():
        """Build the InfoGAN generator.

        Inputs, in this order:
          * Gaussian noise, shape (noise_dim,)
          * continuous latent code, shape (con_dim,)
          * integer class label (one scalar per sample), embedded into 30 dims
        Output: a (28, 28, 1) image passed through tanh, i.e. values in [-1, 1].
        """
        noise_in = layers.Input(shape=(noise_dim,))
        code_in = layers.Input(shape=(con_dim,))
        label_in = layers.Input(shape=())

        def deconv_bn_relu(tensor, filters, padding):
            # One stride-2 upsampling stage: transposed conv -> batch norm -> ReLU.
            tensor = layers.Conv2DTranspose(filters, (3, 3), strides=(2, 2),
                                            padding=padding, use_bias=False)(tensor)
            tensor = layers.BatchNormalization()(tensor)
            return layers.ReLU()(tensor)

        label_vec = layers.Embedding(10, 30, input_length=1)(label_in)
        label_vec = layers.Flatten()(label_vec)

        h = layers.concatenate([noise_in, code_in, label_vec])
        h = layers.Dense(3 * 3 * 128, use_bias=False)(h)
        h = layers.Reshape((3, 3, 128))(h)
        h = layers.BatchNormalization()(h)
        h = layers.ReLU()(h)

        h = deconv_bn_relu(h, 64, 'valid')  # 3x3 -> 7x7 (no padding)
        h = deconv_bn_relu(h, 32, 'same')   # 7x7 -> 14x14

        # 14x14 -> 28x28 single-channel image.
        h = layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', use_bias=False)(h)
        fake_image = layers.Activation('tanh')(h)

        return tf.keras.Model(inputs=[noise_in, code_in, label_in], outputs=fake_image)
    
    def discriminator_model():
        """Build the discriminator together with its auxiliary heads.

        Input: a (28, 28, 1) image.
        Outputs, in this order:
          * one real/fake logit (no activation),
          * 10 class logits (no activation),
          * con_dim sigmoid values in (0, 1): the estimate of the continuous code.
        """
        image = tf.keras.Input(shape=(28, 28, 1))

        features = image
        # Three stride-2 conv stages, spatially 28 -> 14 -> 7 -> 4.
        for filters in (32, 64, 128):
            features = layers.Conv2D(filters, (3, 3), strides=(2, 2),
                                     padding='same', use_bias=False)(features)
            features = layers.BatchNormalization()(features)
            features = layers.LeakyReLU()(features)
            features = layers.Dropout(0.5)(features)

        features = layers.Flatten()(features)
        validity_logit = layers.Dense(1)(features)
        class_logits = layers.Dense(10)(features)
        code_estimate = layers.Dense(con_dim, activation='sigmoid')(features)

        return tf.keras.Model(inputs=image,
                              outputs=[validity_logit, class_logits, code_estimate])
    
    # Build both networks once; the training and sampling code use these globals.
    generator, discriminator = generator_model(), discriminator_model()

    # The real/fake and class heads have no output activation, so both losses
    # are computed from raw logits.
    binary_cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    category_cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    
    def discriminator_loss(real_output, real_cat_out, fake_output, label, con_out, cond_in):
        """Return the discriminator's total loss.

        Sum of:
          * adversarial BCE: real logits pushed to 1, fake logits pushed to 0;
          * sparse categorical CE of the class head on real images vs ``label``;
          * mean squared error between the predicted code ``con_out`` and the
            code ``cond_in`` fed to the generator.
        """
        adversarial = (binary_cross_entropy(tf.ones_like(real_output), real_output)
                       + binary_cross_entropy(tf.zeros_like(fake_output), fake_output))
        categorical = category_cross_entropy(label, real_cat_out)
        code_mse = tf.reduce_mean(tf.square(con_out - cond_in))
        return adversarial + categorical + code_mse
    
    def generator_loss(fake_output, fake_cat_out, label, con_out, cond_in):
        """Return the generator's total loss.

        Sum of: BCE pushing the discriminator's fake logits toward "real" (1),
        sparse categorical CE so fakes are classified as ``label``, and the MSE
        between the recovered code ``con_out`` and the input code ``cond_in``.
        """
        return (binary_cross_entropy(tf.ones_like(fake_output), fake_output)
                + category_cross_entropy(label, fake_cat_out)
                + tf.reduce_mean(tf.square(con_out - cond_in)))
    
    # Separate Adam optimizers for the two networks, same learning rate.
    generator_optimizer, discriminator_optimizer = (
        tf.keras.optimizers.Adam(1e-5),
        tf.keras.optimizers.Adam(1e-5),
    )
    
    @tf.function
    def train_step(images, labels):
        """Run one simultaneous generator/discriminator update on a batch.

        Args:
            images: batch of real images, expected scaled to [-1, 1] with shape
                (batch, 28, 28, 1) as prepared above.
            labels: integer digit labels of ``images``. They are also fed to the
                generator as its category input, so fakes are conditioned on the
                same classes as the real batch.

        Returns:
            ``(gen_loss, disc_loss)`` scalar tensors for this step (callers may
            ignore them).
        """
        # Dynamic batch size: the last batch of an epoch is smaller than
        # BATCH_SIZE, and inside tf.function the static shape may be None.
        batchsize = tf.shape(labels)[0]
        noise = tf.random.normal([batchsize, noise_dim])
        # The continuous code must be con_dim wide to match the generator's code
        # input and the discriminator's code head.
        cond = tf.random.uniform([batchsize, con_dim])

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator((noise, cond, labels), training=True)

            real_output, real_cat_out, _ = discriminator(images, training=True)
            fake_output, fake_cat_out, con_out = discriminator(generated_images, training=True)

            gen_loss = generator_loss(fake_output, fake_cat_out, labels, con_out, cond)
            disc_loss = discriminator_loss(real_output, real_cat_out, fake_output, labels,
                                           con_out, cond)

        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

        return gen_loss, disc_loss
    
    # Fixed inputs for visualising progress: the same noise vectors and random
    # class labels are reused every time samples are plotted during training.
    num = 10
    noise_seed = tf.random.normal([num, noise_dim])
    cat_seed = np.random.randint(10, size=(num, 1))
    print(np.transpose(cat_seed))
    

    [原文插图未保留]

    def generate_and_save_images(model, test_noise_input, test_cat_input, epoch, *, save=False):
        """Plot one row of generator samples for fixed noise and class inputs.

        A fresh continuous code is drawn on each call, sized to match the
        number of noise vectors passed in.

        Args:
            model: generator taking ``(noise, code, label)``.
            test_noise_input: noise batch, shape (n, noise_dim).
            test_cat_input: class labels for the n samples.
            epoch: zero-based epoch index; printed as ``epoch + 1``.
            save: if True, also write the figure to
                ``image_at_epoch_{epoch:04d}.png`` (off by default, as before).
        """
        print('Epoch:', epoch+1)
        n_samples = int(test_noise_input.shape[0])
        cond_seed = tf.random.uniform([n_samples, con_dim])
        # training=False runs all layers (notably BatchNormalization) in inference mode.
        predictions = model((test_noise_input, cond_seed, test_cat_input), training=False)
        # Drop only the channel axis so a batch of one keeps its batch dimension.
        predictions = tf.squeeze(predictions, axis=-1)

        plt.figure(figsize=(n_samples, 1))
        for i in range(n_samples):
            plt.subplot(1, n_samples, i + 1)
            # Map tanh output from [-1, 1] back to [0, 1] for display.
            plt.imshow((predictions[i] + 1) / 2, cmap='gray')
            plt.axis('off')

        if save:
            plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
        plt.show()
    
    def train(dataset, epochs, sample_every=10):
        """Train the GAN for ``epochs`` full passes over ``dataset``.

        Progress samples from the module-level ``noise_seed``/``cat_seed`` are
        shown after epoch 0 and every ``sample_every`` epochs after that, plus
        once more when training finishes.

        Args:
            dataset: iterable of ``(image_batch, label_batch)`` pairs.
            epochs: number of passes over ``dataset``; 0 trains nothing and
                shows nothing.
            sample_every: positive epoch interval between progress plots.
        """
        for epoch in range(epochs):
            for image_batch, label_batch in dataset:
                train_step(image_batch, label_batch)
            if epoch % sample_every == 0:
                generate_and_save_images(generator, noise_seed, cat_seed, epoch)

        # Use the last epoch index explicitly; the loop variable is unbound
        # when epochs == 0.
        if epochs > 0:
            generate_and_save_images(generator, noise_seed, cat_seed, epochs - 1)
    
    # Total number of passes over the training set.
    EPOCHS = 200

    train(dataset, epochs=EPOCHS)
    

    [原文两张插图未保留]

    generator.save('generate_infogan.h5')

    # One sample per digit class 0..9, each with its own Gaussian noise vector.
    num = 10
    noise_seed = tf.random.normal([num, noise_dim])
    cat_seed = np.arange(10)[:, np.newaxis]
    print(np.transpose(cat_seed))
    

    [原文三张插图未保留]

  • 相关阅读:
    深入理解yield from语法
    数据库事务并发问题,锁机制和对应的4种隔离级别
    同源策略与CORS跨域请求
    Restful 4 -- 认证组件、权限组件、频率组件、url注册器、响应器、分页器
    Restful 3 -- 序列化组件(GET/PUT/DELETE接口设计)、视图优化组件
    Restful 2 --DRF解析器,序列化组件使用(GET/POST接口设计)
    Restful 1 -- REST、DRF(View源码解读、APIView源码解读)及框架实现
    Vue(7)- vue-cookies、极验滑动验证geetest、vue-router的导航守卫
    Vue --6 router进阶、单页面应用(SPA)带来的问题
    Vue 5 -- axios、vuex
  • 原文地址:https://www.cnblogs.com/gemoumou/p/14186248.html
Copyright © 2011-2022 走看看