zoukankan      html  css  js  c++  java
  • GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10

    在这里插入图片描述
    在这里插入图片描述
    CycleGAN的原理可以概述为:
    将一类图片转换成另一类图片 。也就是说,现在有两个样
    本空间,X和Y,我们希望把X空间中的样本转换成Y空间中
    的样本。(获取一个数据集的特征,并转化成另一个数据
    集的特征)
    这样来看:实际的目标就是学习从X到Y的映射。我们设这
    个映射为F。它就对应着GAN中的 生成器 ,F可以将X中的
    图片x转换为Y中的图片F(x)。对于生成的图片,我们还需要
    GAN中的 判别器 来判别它是否为真实图片,由此构成对抗
    生成网络
    在这里插入图片描述
    在足够大的样本容量下,网络可以将相同的输入图像集合
    映射到目标域中图像的任何随机排列,其中任何学习的映
    射可以归纳出与目标分布匹配的输出分布(即:映射F完全
    可以将所有x都映射为Y空间中的同一张图片,使损失无效
    化)。
    因此,单独的对抗损失Loss不能保证学习函数可以
    将单个输入Xi映射到期望的输出Yi。
    对此,作者又提出了所谓的“循环一致性损失”
    (cycle consistency loss)。
    在这里插入图片描述
    我们希望能够把 domain A 的图片(命名为 a)转
    化为 domain B 的图片(命名为图片 b)。
    为了实现这个过程,我们需要两个生成器 G_AB 和
    G_BA,分别把 domain A 和 domain B 的图片进行
    互相转换。
    将X的图片转换到Y空间后,应该还可以转换回来。
    这样就杜绝模型把所有X的图片都转换为Y空间中的
    同一张图片了
    最后为了训练这个单向 GAN 需要两个 loss,分别是
    生成器的重建 loss 和判别器的判别 loss。
    判别 loss:判别器 D_B 是用来判断输入的图片是否
    是真实的 domain B 图片
    在这里插入图片描述
    在这里插入图片描述
    CycleGAN 其实就是一个 A→B 单向 GAN 加上一个
    B→A 单向 GAN。两个 GAN 共享两个生成器,然
    后各自带一个判别器,所以加起来总共有两个判别器
    和两个生成器。
    一个单向 GAN 有两个 loss,而 CycleGAN 加起来
    总共有四个 loss。
    在这里插入图片描述
    在这里插入图片描述
    对颜色、纹理等的转换效果比较好,对多样性高的、
    多变的转换效果不好(如几何转换)

    代码

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    import tensorflow as tf
    import glob
    from matplotlib import pyplot as plt
    %matplotlib inline
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    import os
    
    os.listdir('../input/apple2orange/apple2orange')
    

    在这里插入图片描述

    imgs_A = glob.glob('../input/apple2orange/apple2orange/trainA/*.jpg')
    

    在这里插入图片描述

    imgs_B = glob.glob('../input/apple2orange/apple2orange/trainB/*.jpg')
    

    在这里插入图片描述

    test_A = glob.glob('../input/apple2orange/apple2orange/testA/*.jpg')
    test_B = glob.glob('../input/apple2orange/apple2orange/testB/*.jpg')
    

    在这里插入图片描述

    def read_jpg(path):
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=3)
        return img
    
    def normalize(input_image):
        input_image = tf.cast(input_image, tf.float32)/127.5 - 1
        return input_image
    
    def load_image(image_path):
        image = read_jpg(image_path)
        image = tf.image.resize(image, (256, 256))
        image = normalize(image)
        return image
    
    train_a = tf.data.Dataset.from_tensor_slices(imgs_A)
    train_b = tf.data.Dataset.from_tensor_slices(imgs_B)
    test_a = tf.data.Dataset.from_tensor_slices(test_A)
    test_b = tf.data.Dataset.from_tensor_slices(test_B)
    
    BUFFER_SIZE = 200
    
    train_a = train_a.map(load_image, 
                          num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
    train_b = train_b.map(load_image, 
                          num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
    test_a = test_a.map(load_image, 
                          num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
    test_b = test_b.map(load_image, 
                          num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
    
    data_train = tf.data.Dataset.zip((train_a, train_b))
    data_test = tf.data.Dataset.zip((test_a, test_b))
    
    plt.figure(figsize=(6, 3))
    for img, musk in zip(train_a.take(1), train_b.take(1)):
        plt.subplot(1,2,1)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(img[0]))
        plt.subplot(1,2,2)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(musk[0]))
    

    在这里插入图片描述
    实例归一化

    !pip install tensorflow_addons
    
    import tensorflow_addons as tfa
    
    OUTPUT_CHANNELS = 3
    
    def downsample(filters, size, apply_batchnorm=True):
    #    initializer = tf.random_normal_initializer(0., 0.02)
    
        result = tf.keras.Sequential()
        result.add(
            tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                                   use_bias=False))
    
        if apply_batchnorm:
            result.add(tfa.layers.InstanceNormalization())
    
            result.add(tf.keras.layers.LeakyReLU())
    
        return result
    
    def upsample(filters, size, apply_dropout=False):
    #    initializer = tf.random_normal_initializer(0., 0.02)
    
        result = tf.keras.Sequential()
        result.add(
            tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                            padding='same',
                                            use_bias=False))
    
        result.add(tfa.layers.InstanceNormalization())
    
        if apply_dropout:
            result.add(tf.keras.layers.Dropout(0.5))
    
        result.add(tf.keras.layers.ReLU())
    
        return result
    
    def Generator():
        inputs = tf.keras.layers.Input(shape=[256,256,3])
    
        down_stack = [
            downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
            downsample(128, 4), # (bs, 64, 64, 128)
            downsample(256, 4), # (bs, 32, 32, 256)
            downsample(512, 4), # (bs, 16, 16, 512)
            downsample(512, 4), # (bs, 8, 8, 512)
            downsample(512, 4), # (bs, 4, 4, 512)
            downsample(512, 4), # (bs, 2, 2, 512)
            downsample(512, 4), # (bs, 1, 1, 512)
        ]
    
        up_stack = [
            upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
            upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
            upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
            upsample(512, 4), # (bs, 16, 16, 1024)
            upsample(256, 4), # (bs, 32, 32, 512)
            upsample(128, 4), # (bs, 64, 64, 256)
            upsample(64, 4), # (bs, 128, 128, 128)
        ]
    
    #    initializer = tf.random_normal_initializer(0., 0.02)
        last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                             strides=2,
                                             padding='same',
                                             activation='tanh') # (bs, 256, 256, 3)
    
        x = inputs
    
        # Downsampling through the model
        skips = []
        for down in down_stack:
            x = down(x)
            skips.append(x)
    
        skips = reversed(skips[:-1])
    
        # Upsampling and establishing the skip connections
        for up, skip in zip(up_stack, skips):
            x = up(x)
            x = tf.keras.layers.Concatenate()([x, skip])
    
        x = last(x)
    
        return tf.keras.Model(inputs=inputs, outputs=x)
    
    generator_x = Generator()   # a——>o
    generator_y = Generator()   # o——>a
    #tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)
    
    def Discriminator():
    #    initializer = tf.random_normal_initializer(0., 0.02)
    
        inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
    
        down1 = downsample(64, 4, False)(inp) # (bs, 128, 128, 64)
        down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
        down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)
    
        zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
        conv = tf.keras.layers.Conv2D(
                   512, 4, strides=1,use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)
    
        norm1 = tfa.layers.InstanceNormalization()(conv)
    
        leaky_relu = tf.keras.layers.LeakyReLU()(norm1)
    
        zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)
    
        last = tf.keras.layers.Conv2D(
                   1, 4, strides=1)(zero_pad2)  # (bs, 30, 30, 1)
    
        return tf.keras.Model(inputs=inp, outputs=last)
    
    discriminator_x = Discriminator()   # discriminator  a
    discriminator_y = Discriminator()   # discriminator  o
    #tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)
    
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    
    def discriminator_loss(disc_real_output, disc_generated_output):
        real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
        generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
        total_disc_loss = real_loss + generated_loss
        return total_disc_loss
    
    def generator_loss(disc_generated_output):
        gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
        return gan_loss
    
    LAMBDA = 7
    
    def calc_cycle_loss(real_image, cycled_image):
        loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * loss1
    
    generator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    generator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    
    def generate_images(model, test_input):
        prediction = model(test_input, training=True)
        plt.figure(figsize=(15,15))
    
        display_list = [test_input[0], prediction[0]]
        title = ['Input Image', 'Predicted Image']
    
        for i in range(2):
            plt.subplot(1, 2, i+1)
            plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
            plt.imshow(display_list[i] * 0.5 + 0.5)
            plt.axis('off')
        plt.show()
    
    @tf.function
    def train_step(image_a, image_b):
        with tf.GradientTape(persistent=True) as tape:
            fake_b = generator_x(image_a, training=True)
            cycled_a = generator_y(fake_b, training=True)
    
            fake_a = generator_y(image_b, training=True)
            cycled_b = generator_x(fake_a, training=True)
            
            disc_real_a = discriminator_x(image_a, training=True)
            disc_real_b = discriminator_y(image_b, training=True)
    
            disc_fake_a = discriminator_x(fake_a, training=True)
            disc_fake_b = discriminator_y(fake_b, training=True)
            
            gen_x_loss = generator_loss(disc_fake_b)
            gen_y_loss = generator_loss(disc_fake_a)
        
            total_cycle_loss = (calc_cycle_loss(image_a, cycled_a) 
                                   + calc_cycle_loss(image_b, cycled_b))
        
            # 总生成器损失 = 对抗性损失 + 循环损失。
            total_gen_x_loss = gen_x_loss + total_cycle_loss
            total_gen_y_loss = gen_y_loss + total_cycle_loss
    
            disc_x_loss = discriminator_loss(disc_real_a, disc_fake_a)
            disc_y_loss = discriminator_loss(disc_real_b, disc_fake_b)
      
        # 计算生成器和判别器损失。
        generator_x_gradients = tape.gradient(total_gen_x_loss, 
                                            generator_x.trainable_variables)
        generator_y_gradients = tape.gradient(total_gen_y_loss, 
                                            generator_y.trainable_variables)
      
        discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                                discriminator_x.trainable_variables)
        discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                                discriminator_y.trainable_variables)
        
        # 将梯度应用于优化器。
        generator_x_optimizer.apply_gradients(zip(generator_x_gradients, 
                                                  generator_x.trainable_variables))
    
        generator_y_optimizer.apply_gradients(zip(generator_y_gradients, 
                                                  generator_y.trainable_variables))
      
        discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                      discriminator_x.trainable_variables))
      
        discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                      discriminator_y.trainable_variables))
    
    def fit(train_ds, test_ds, epochs):
        for epoch in range(epochs+1):
            for img_a, img_b in train_ds:
                train_step(img_a, img_b)
            print ('.', end='')
    
            if epoch % 5 == 0:
                print()
                for test_a, test_b in test_ds.take(1):
                    print("Epoch: ", epoch)
                    generate_images(generator_x, test_a)
        generate_images(generator_x, test_a)
    
    EPOCHS = 100
    
    fit(data_train, data_test,  EPOCHS)
    

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

  • 相关阅读:
    内存映射mmap的几个api及其使用
    hiredis的安装
    Linux 下解压大全
    redis内存数据库C客户端hiredis API 中文说明
    C/C++使用MySQL
    搜索引擎的缓存(cache)机制
    快速排序(QuickSort)
    冒泡排序
    spring核心之AOP学习总结一
    Spring学习总结六——SpringMVC一
  • 原文地址:https://www.cnblogs.com/gemoumou/p/14186244.html
Copyright © 2011-2022 走看看