zoukankan      html  css  js  c++  java
  • TensorFlow 2.0 学习(十六):生成对抗网络 GAN 与 WGAN

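    DCGAN 网络的结构(下面用文字概括代码中实现的结构):生成器把 (b, 100) 的隐向量 reshape 为 (b, 1, 1, 100),再经过 5 层转置卷积依次放大为 (b, 4, 4, 512) → (b, 8, 8, 256) → (b, 16, 16, 128) → (b, 32, 32, 64) → (b, 64, 64, 3),除最后一层外每层都接 BatchNorm 和 ReLU,输出经 tanh 映射到 [-1, 1];判别器把 (b, 64, 64, 3) 的图片经过 5 层卷积(BatchNorm + LeakyReLU)缩小为 (b, 2, 2, 1024),再经全局平均池化和全连接层输出 (b, 1) 的 logits。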

     代码包括以下五部分:数据集加载(Dataset.py)、GAN 模型(GAN.py)、GAN 训练脚本、WGAN-GP 模型(WGAN.py)和 WGAN-GP 训练脚本。

    数据集加载(Dataset.py):

      1 import tensorflow as tf
      2 import multiprocessing
      3 
      4 
      5 def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
      6     @tf.function
      7     def _map_fn(img):
      8         img = tf.image.resize(img, [resize, resize])
      9         img = tf.clip_by_value(img, 0, 255)
     10         img = img / 127.5 - 1
     11 
     12         return img
     13 
     14     dataset = disk_image_batch_dataset(img_paths, batch_size, drop_remainder=drop_remainder,
     15                                        map_fn=_map_fn, shuffle=shuffle, repeat=repeat)
     16     img_shape = (resize, resize, 3)
     17     len_dataset = len(img_paths) // batch_size
     18 
     19     return dataset, img_shape, len_dataset
     20 
     21 
     22 def batch_dataset(dataset,
     23                   batch_size,
     24                   drop_remainder=True,
     25                   n_prefetch_batch=1,
     26                   filter_fn=None,
     27                   map_fn=None,
     28                   n_map_threads=None,
     29                   filter_after_map=False,
     30                   shuffle=True,
     31                   shuffle_buffer_size=None,
     32                   repeat=None):
     33     # set defaults
     34     if n_map_threads is None:
     35         n_map_threads = multiprocessing.cpu_count()
     36 
     37     if shuffle and shuffle_buffer_size is None:
     38         shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048
     39 
     40     # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
     41     if shuffle:
     42         dataset = dataset.shuffle(shuffle_buffer_size)
     43 
     44     if not filter_after_map:
     45         if filter_fn:
     46             dataset = dataset.filter(filter_fn)
     47 
     48         if map_fn:
     49             dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
     50 
     51     else:  # [*] this is slower
     52         if map_fn:
     53             dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
     54 
     55         if filter_fn:
     56             dataset = dataset.filter(filter_fn)
     57 
     58     dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
     59     dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)
     60 
     61     return dataset
     62 
     63 
     64 def memory_data_batch_dataset(memory_data,
     65                               batch_size,
     66                               drop_remainder=True,
     67                               n_prefetch_batch=1,
     68                               filter_fn=None,
     69                               map_fn=None,
     70                               n_map_threads=None,
     71                               filter_after_map=False,
     72                               shuffle=True,
     73                               shuffle_buffer_size=None,
     74                               repeat=None):
     75     """Batch dataset of memory data.
     76     Parameters
     77     ----------
     78     memory_data : nested structure of tensors/ndarrays/lists
     79     """
     80 
     81     dataset = tf.data.Dataset.from_tensor_slices(memory_data)
     82     dataset = batch_dataset(dataset, batch_size,
     83                             drop_remainder=drop_remainder,
     84                             n_prefetch_batch=n_prefetch_batch,
     85                             filter_fn=filter_fn,
     86                             map_fn=map_fn,
     87                             n_map_threads=n_map_threads,
     88                             filter_after_map=filter_after_map,
     89                             shuffle=shuffle,
     90                             shuffle_buffer_size=shuffle_buffer_size,
     91                             repeat=repeat)
     92 
     93     return dataset
     94 
     95 
     96 def disk_image_batch_dataset(img_paths,
     97                              batch_size,
     98                              labels=None,
     99                              drop_remainder=True,
    100                              n_prefetch_batch=1,
    101                              filter_fn=None,
    102                              map_fn=None,
    103                              n_map_threads=None,
    104                              filter_after_map=False,
    105                              shuffle=True,
    106                              shuffle_buffer_size=None,
    107                              repeat=None):
    108     """Batch dataset of disk image for PNG and JPEG.
    109     Parameters
    110     ----------
    111         img_paths : 1d-tensor/ndarray/list of str
    112         labels : nested structure of tensors/ndarrays/lists
    113     """
    114 
    115     if labels is None:
    116         memory_data = img_paths
    117 
    118     else:
    119         memory_data = (img_paths, labels)
    120 
    121     def parse_fn(path, *label):
    122         img = tf.io.read_file(path)
    123         img = tf.image.decode_png(img, 3)  # fix channels to 3
    124         return (img,) + label
    125 
    126     if map_fn:  # fuse `map_fn` and `parse_fn`
    127         def map_fn_(*args):
    128             return map_fn(*parse_fn(*args))
    129     else:
    130         map_fn_ = parse_fn
    131 
    132     dataset = memory_data_batch_dataset(memory_data,
    133                                         batch_size,
    134                                         drop_remainder=drop_remainder,
    135                                         n_prefetch_batch=n_prefetch_batch,
    136                                         filter_fn=filter_fn,
    137                                         map_fn=map_fn_,
    138                                         n_map_threads=n_map_threads,
    139                                         filter_after_map=filter_after_map,
    140                                         shuffle=shuffle,
    141                                         shuffle_buffer_size=shuffle_buffer_size,
    142                                         repeat=repeat)
    143 
    144     return dataset

    GAN 模型(GAN.py):

     1 import tensorflow as tf
     2 from tensorflow.keras import layers, Model
     3 
     4 
     5 class Generator(Model):
     6     # 生成器网络类
     7     def __init__(self):
     8         super(Generator, self).__init__()
     9         filter = 64
    10         # 转置卷积层1,输出channel 为filter*8,核大小4,步长1,不使用padding,不使用偏置
    11         self.conv1 = layers.Conv2DTranspose(filter*8, 4,1, 'valid', use_bias=False)
    12         self.bn1 = layers.BatchNormalization()
    13         # 转置卷积层2
    14         self.conv2 = layers.Conv2DTranspose(filter * 4, 4, 2, 'same', use_bias=False)
    15         self.bn2 = layers.BatchNormalization()
    16         # 转置卷积层3
    17         self.conv3 = layers.Conv2DTranspose(filter * 2, 4, 2, 'same', use_bias=False)
    18         self.bn3 = layers.BatchNormalization()
    19         # 转置卷积层4
    20         self.conv4 = layers.Conv2DTranspose(filter * 1, 4, 2, 'same', use_bias=False)
    21         self.bn4 = layers.BatchNormalization()
    22         # 转置卷积层5
    23         self.conv5 = layers.Conv2DTranspose(3, 4, 2, 'same', use_bias=False)
    24 
    25     def call(self, inputs, training=None):
    26         x = inputs  # [z, 100]
    27         # Reshape 乘4D 张量,方便后续转置卷积运算:(b, 1, 1, 100)
    28         x = tf.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
    29         x = tf.nn.relu(x)  # 激活函数
    30         # 转置卷积-BN-激活函数:(b, 4, 4, 512)
    31         x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
    32         # 转置卷积-BN-激活函数:(b, 8, 8, 256)
    33         x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
    34         # 转置卷积-BN-激活函数:(b, 16, 16, 128)
    35         x = tf.nn.relu(self.bn3(self.conv3(x), training=training))
    36         # 转置卷积-BN-激活函数:(b, 32, 32, 64)
    37         x = tf.nn.relu(self.bn4(self.conv4(x), training=training))
    38         # 转置卷积-激活函数:(b, 64, 64, 3)
    39         x = self.conv5(x)
    40         x = tf.tanh(x)  # 输出x 范围-1~1,与预处理一致
    41 
    42         return x
    43 
    44 
    45 class Discriminator(Model):
    46     # 判别器类
    47     def __init__(self):
    48         super(Discriminator, self).__init__()
    49         filter = 64
    50         # 卷积层1
    51         self.conv1 = layers.Conv2D(filter, 4, 2, 'valid', use_bias=False)
    52         self.bn1 = layers.BatchNormalization()
    53         # 卷积层2
    54         self.conv2 = layers.Conv2D(filter*2, 4, 2, 'valid', use_bias=False)
    55         self.bn2 = layers.BatchNormalization()
    56         # 卷积层3
    57         self.conv3 = layers.Conv2D(filter * 4, 4, 2, 'valid', use_bias=False)
    58         self.bn3 = layers.BatchNormalization()
    59         # 卷积层4
    60         self.conv4 = layers.Conv2D(filter * 8, 3, 1, 'valid', use_bias=False)
    61         self.bn4 = layers.BatchNormalization()
    62         # 卷积层5
    63         self.conv5 = layers.Conv2D(filter * 16, 3, 1, 'valid', use_bias=False)
    64         self.bn5 = layers.BatchNormalization()
    65         # 全局池化层
    66         self.pool = layers.GlobalAveragePooling2D()
    67         # 特征打平层
    68         self.flatten = layers.Flatten()
    69         # 2 分类全连接层
    70         self.fc = layers.Dense(1)
    71 
    72     def call(self, inputs, training=None):
    73         # 卷积-BN-激活函数:(4, 31, 31, 64)
    74         x = tf.nn.leaky_relu(self.bn1(self.conv1(inputs), training=training) )
    75         # 卷积-BN-激活函数:(4, 14, 14, 128)
    76         x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
    77         # 卷积-BN-激活函数:(4, 6, 6, 256)
    78         x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
    79         # 卷积-BN-激活函数:(4, 4, 4, 512)
    80         x = tf.nn.leaky_relu(self.bn4(self.conv4(x), training=training))
    81         # 卷积-BN-激活函数:(4, 2, 2, 1024)
    82         x = tf.nn.leaky_relu(self.bn5(self.conv5(x), training=training))
    83         # 卷积-BN-激活函数:(4, 1024)
    84         x = self.pool(x)
    85         # 打平
    86         x = self.flatten(x)
    87         # 输出,[b, 1024] => [b, 1]
    88         logits = self.fc(x)
    89 
    90         return logits

    GAN 训练脚本:

      1 import os
      2 import glob
      3 import numpy as np
      4 
      5 import tensorflow as tf
      6 from tensorflow import keras
      7 
      8 from GAN import Generator, Discriminator
      9 from Dataset import make_anime_dataset
     10 
     11 from PIL import Image
     12 import scipy.misc
     13 import matplotlib.pyplot as plt
     14 
     15 
     16 def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
     17     # 计算判别器的误差函数
     18     # 采样生成图片
     19     fake_image = generator(batch_z, is_training)
     20     # 判定生成图片
     21     d_fake_logits = discriminator(fake_image, is_training)
     22     # 判定真实图片
     23     d_real_logits = discriminator(batch_x, is_training)
     24     # 真实图片与1 之间的误差
     25     d_loss_real = celoss_ones(d_real_logits)
     26     # 生成图片与0 之间的误差
     27     d_loss_fake = celoss_zeros(d_fake_logits)
     28     # 合并误差
     29     loss = d_loss_fake + d_loss_real
     30 
     31     return loss
     32 
     33 
     34 def celoss_ones(logits):
     35     # 计算属于与标签为1 的交叉熵
     36     y = tf.ones_like(logits)
     37     loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
     38 
     39     return tf.reduce_mean(loss)
     40 
     41 
     42 def celoss_zeros(logits):
     43     # 计算属于与便签为0 的交叉熵
     44     y = tf.zeros_like(logits)
     45     loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
     46 
     47     return tf.reduce_mean(loss)
     48 
     49 
     50 def g_loss_fn(generator, discriminator, batch_z, is_training):
     51     # 采样生成图片
     52     fake_image = generator(batch_z, is_training)
     53     # 在训练生成网络时,需要迫使生成图片判定为真
     54     d_fake_logits = discriminator(fake_image, is_training)
     55     # 计算生成图片与1 之间的误差
     56     loss = celoss_ones(d_fake_logits)
     57 
     58     return loss
     59 
     60 
     61 def save_result(val_out, val_block_size, image_path, color_mode):
     62     def preprocess(img):
     63         img = ((img + 1.0) * 127.5).astype(np.uint8)
     64         # img = img.astype(np.uint8)
     65         return img
     66 
     67     preprocesed = preprocess(val_out)
     68     final_image = np.array([])
     69     single_row = np.array([])
     70 
     71     for b in range(val_out.shape[0]):
     72         # concat image into a row
     73         if single_row.size == 0:
     74             single_row = preprocesed[b, :, :, :]
     75         else:
     76             single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)
     77 
     78         # concat image row to final_image
     79         if (b + 1) % val_block_size == 0:
     80             if final_image.size == 0:
     81                 final_image = single_row
     82             else:
     83                 final_image = np.concatenate((final_image, single_row), axis=0)
     84 
     85             # reset single row
     86             single_row = np.array([])
     87 
     88     if final_image.shape[2] == 1:
     89         final_image = np.squeeze(final_image, axis=2)
     90     im = Image.fromarray(final_image)
     91     im.save('exam11_final_image.png')
     92     # Image.save(final_image)
     93     # Image(final_image).save(image_path)
     94 
     95 
     96 d_losses, g_losses = [], []
     97 
     98 
     99 def draw():
    100     plt.figure()
    101     plt.plot(d_losses, 'b', label='generator')
    102     plt.plot(g_losses, 'r', label='discriminator')
    103     plt.xlabel('Epoch')
    104     plt.ylabel('ACC')
    105     plt.legend()
    106     plt.savefig('exam11.1_train_test_VAE.png')
    107     plt.show()
    108 
    109 
    110 def main():
    111     batch_size = 64
    112     learning_rate = 0.0002
    113     z_dim = 100
    114     is_training = True
    115     epochs = 300
    116 
    117     img_path = glob.glob(r'G:2020pythonfacesfaces*.jpg')
    118     print('images num:', len(img_path))
    119     # 构建数据集对象,返回数据集Dataset 类和图片大小
    120     dataset, img_shape, _ = make_anime_dataset(img_path, batch_size, resize=64)  # (64, 64, 64, 3) (64, 64, 3)
    121     sample = next(iter(dataset))  # 采样  (64, 64, 64, 3)
    122     print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())  # (64, 64, 64, 3) 1.0 -1.0
    123     dataset = dataset.repeat(100)  # 重复循环
    124     db_iter = iter(dataset)
    125 
    126     generator = Generator()  # 创建生成器
    127     generator.build(input_shape=(4, z_dim))
    128     discriminator = Discriminator()  # 创建判别器
    129     discriminator.build(input_shape=(4, 64, 64, 3))
    130     # 分别为生成器和判别器创建优化器
    131     g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    132     d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    133 
    134     # generator.load_weights('exam11.1_generator.ckpt')
    135     # discriminator.load_weights('exam11.1_discriminator.ckpt')
    136     # print('Loaded chpt!!')
    137 
    138     for epoch in range(epochs):  # 训练epochs 次
    139         # 1. 训练判别器
    140         for _ in range(5):
    141             # 采样隐藏向量
    142             batch_z = tf.random.normal([batch_size, z_dim])
    143             batch_x = next(db_iter)  # 采样真实图片
    144             # 判别器前向计算
    145             with tf.GradientTape() as tape:
    146                 d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
    147                 grads = tape.gradient(d_loss, discriminator.trainable_variables)
    148                 d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    149 
    150         # 2. 训练生成器
    151         # 采样隐藏向量
    152         batch_z = tf.random.normal([batch_size, z_dim])
    153         batch_x = next(db_iter)  # 采样真实图片
    154         # 生成器前向计算
    155         with tf.GradientTape() as tape:
    156             g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
    157         grads = tape.gradient(g_loss, generator.trainable_variables)
    158         g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    159 
    160         if epoch % 100 == 0:
    161             print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))  # 可视化
    162             z = tf.random.normal([100, z_dim])
    163             fake_image = generator(z, training=False)
    164             img_path = os.path.join('gan_images', 'gan-%d.png' % epoch)
    165             save_result(fake_image.numpy(), 10, img_path, color_mode='P')
    166 
    167             d_losses.append(float(d_loss))
    168             g_losses.append(float(g_loss))
    169 
    170             if epoch % 10000 == 1:
    171                 # print(d_losses)
    172                 # print(g_losses)
    173                 generator.save_weights('exam11.1_generator.ckpt')
    174                 discriminator.save_weights('exam11.1_discriminator.ckpt')
    175 
    176 
    177 if __name__ == '__main__':
    178     main()
    179     draw()

    没有得到结果(代码没有报错),个人认为主要是受机器性能的限制。另外要注意:训练循环里的每个 epoch 实际上只对应一次生成器更新(外加 5 次判别器更新),epochs = 300 只相当于 300 个训练步,通常不足以让生成器产生可辨认的图像。

    WGAN-GP 模型(WGAN.py):

     1 import tensorflow as tf
     2 from tensorflow.keras import layers, Model
     3 
     4 
     5 class Generator(Model):
     6     def __init__(self):
     7         super(Generator, self).__init__()
     8         # z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
     9         self.fc = layers.Dense(3*3*512)
    10         self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
    11         self.bn1 = layers.BatchNormalization()
    12 
    13         self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
    14         self.bn2 = layers.BatchNormalization()
    15         self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')
    16 
    17     def call(self, inputs, training=None):
    18         # [z, 100] => [z, 3*3*512]
    19         x = self.fc(inputs)
    20         x = tf.reshape(x, [-1, 3, 3, 512])
    21         x = tf.nn.leaky_relu(x)
    22 
    23         #
    24         x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
    25         x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
    26         x = self.conv3(x)
    27         x = tf.tanh(x)
    28 
    29         return x
    30 
    31 
    32 class Discriminator(Model):
    33     def __init__(self):
    34         super(Discriminator, self).__init__()
    35 
    36         # [b, 64, 64, 3] => [b, 1]
    37         self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
    38         self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
    39         self.bn2 = layers.BatchNormalization()
    40 
    41         self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
    42         self.bn3 = layers.BatchNormalization()
    43 
    44         # [b, h, w ,c] => [b, -1]
    45         self.flatten = layers.Flatten()
    46         self.fc = layers.Dense(1)
    47 
    48 
    49     def call(self, inputs, training=None):
    50         x = tf.nn.leaky_relu(self.conv1(inputs))
    51         x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
    52         x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
    53 
    54         # [b, h, w, c] => [b, -1]
    55         x = self.flatten(x)
    56 
    57         # [b, -1] => [b, 1]
    58         logits = self.fc(x)
    59         return logits
    60 
    61 
    62 def main():
    63     d = Discriminator()
    64     g = Generator()
    65 
    66     x = tf.random.normal([2, 64, 64, 3])
    67     z = tf.random.normal([2, 100])
    68 
    69     prob = d(x)
    70     print(prob)
    71     x_hat = g(z)
    72     print(x_hat.shape)
    73 
    74 
    75 if __name__ == '__main__':
    76     main()

    WGAN-GP 训练脚本:

      1 import os
      2 import glob
      3 import numpy as np
      4 
      5 import tensorflow as tf
      6 from tensorflow import keras
      7 
      8 from WGAN import Generator, Discriminator
      9 from Dataset import make_anime_dataset
     10 
     11 from PIL import Image
     12 import matplotlib.pyplot as plt
     13 
     14 
     15 def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
     16     # 计算D 的损失函数
     17     fake_image = generator(batch_z, is_training) # 假样本
     18     d_fake_logits = discriminator(fake_image, is_training) # 假样本的输出
     19     d_real_logits = discriminator(batch_x, is_training) # 真样本的输出
     20     # 计算梯度惩罚项
     21     gp = gradient_penalty(discriminator, batch_x, fake_image)
     22     # WGAN-GP D 损失函数的定义,这里并不是计算交叉熵,而是直接最大化正样本的输出
     23     # 最小化假样本的输出和梯度惩罚项
     24     loss = tf.reduce_mean(d_fake_logits) - tf.reduce_mean(d_real_logits) + 10. * gp
     25 
     26     return loss, gp
     27 
     28 
     29 def celoss_ones(logits):
     30     # 计算属于与标签为1 的交叉熵
     31     y = tf.ones_like(logits)
     32     loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
     33 
     34     return tf.reduce_mean(loss)
     35 
     36 
     37 def celoss_zeros(logits):
     38     # 计算属于与便签为0 的交叉熵
     39     y = tf.zeros_like(logits)
     40     loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
     41 
     42     return tf.reduce_mean(loss)
     43 
     44 
     45 def gradient_penalty(discriminator, batch_x, fake_image):
     46     # 梯度惩罚项计算函数
     47     batchsz = batch_x.shape[0]
     48 
     49     # 每个样本均随机采样t,用于插值
     50     t = tf.random.uniform([batchsz, 1, 1, 1])
     51     # 自动扩展为x 的形状,[b, 1, 1, 1] => [b, h, w, c]
     52     t = tf.broadcast_to(t, batch_x.shape)
     53 
     54     # 在真假图片之间做线性插值
     55     interplate = t * batch_x + (1 - t) * fake_image
     56     # 在梯度环境中计算D 对插值样本的梯度
     57     with tf.GradientTape() as tape:
     58         tape.watch([interplate])  # 加入梯度观察列表
     59         d_interplote_logits = discriminator(interplate)
     60     grads = tape.gradient(d_interplote_logits, interplate)
     61 
     62     # 计算每个样本的梯度的范数:[b, h, w, c] => [b, -1]
     63     grads = tf.reshape(grads, [grads.shape[0], -1])
     64     gp = tf.norm(grads, axis=1)  # [b]
     65     # 计算梯度惩罚项
     66     gp = tf.reduce_mean((gp - 1.) ** 2)
     67 
     68     return gp
     69 
     70 
     71 def g_loss_fn(generator, discriminator, batch_z, is_training):
     72     # 生成器的损失函数
     73     fake_image = generator(batch_z, is_training)
     74     d_fake_logits = discriminator(fake_image, is_training)
     75     # WGAN-GP G 损失函数,最大化假样本的输出值
     76     loss = - tf.reduce_mean(d_fake_logits)
     77 
     78     return loss
     79 
     80 
     81 def save_result(val_out, val_block_size, image_path, color_mode):
     82     def preprocess(img):
     83         img = ((img + 1.0) * 127.5).astype(np.uint8)
     84         # img = img.astype(np.uint8)
     85         return img
     86 
     87     preprocesed = preprocess(val_out)
     88     final_image = np.array([])
     89     single_row = np.array([])
     90 
     91     for b in range(val_out.shape[0]):
     92         # concat image into a row
     93         if single_row.size == 0:
     94             single_row = preprocesed[b, :, :, :]
     95         else:
     96             single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)
     97 
     98         # concat image row to final_image
     99         if (b + 1) % val_block_size == 0:
    100             if final_image.size == 0:
    101                 final_image = single_row
    102             else:
    103                 final_image = np.concatenate((final_image, single_row), axis=0)
    104 
    105             # reset single row
    106             single_row = np.array([])
    107 
    108     if final_image.shape[2] == 1:
    109         final_image = np.squeeze(final_image, axis=2)
    110     im = Image.fromarray(final_image)
    111     im.save('exam11_WGAN_final_image.png')
    112     # Image.save(final_image)
    113     # Image(final_image).save(image_path)
    114 
    115 
    116 d_losses, g_losses = [], []
    117 
    118 
    119 def draw():
    120     plt.figure()
    121     plt.plot(d_losses, 'b', label='generator')
    122     plt.plot(g_losses, 'r', label='discriminator')
    123     plt.xlabel('Epoch')
    124     plt.ylabel('ACC')
    125     plt.legend()
    126     plt.savefig('exam11.2_train_test_VAE.png')
    127     plt.show()
    128 
    129 
    130 def main():
    131     batch_size = 512
    132     learning_rate = 0.002
    133     z_dim = 100
    134     is_training = True
    135     epochs = 300
    136 
    137     img_path = glob.glob(r'G:2020pythonfacesfaces*.jpg')
    138     print('images num:', len(img_path))  # images num: 51223
    139     # 构建数据集对象,返回数据集Dataset 类和图片大小
    140     dataset, img_shape, _ = make_anime_dataset(img_path, batch_size, resize=64)  # (512, 64, 64, 3) (64, 64, 3)
    141     sample = next(iter(dataset))  # 采样  (512, 64, 64, 3)
    142     print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())  # (512, 64, 64, 3) 1.0 -1.0
    143     dataset = dataset.repeat(100)  # 重复循环
    144     db_iter = iter(dataset)
    145 
    146     generator = Generator()  # 创建生成器
    147     generator.build(input_shape=(None, z_dim))
    148     discriminator = Discriminator()  # 创建判别器
    149     discriminator.build(input_shape=(None, 64, 64, 3))
    150     # 分别为生成器和判别器创建优化器
    151     g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    152     d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    153 
    154     # generator.load_weights('exam11.1_generator.ckpt')
    155     # discriminator.load_weights('exam11.1_discriminator.ckpt')
    156     # print('Loaded chpt!!')
    157 
    158     for epoch in range(epochs):  # 训练epochs 次
    159         # 采样隐藏向量
    160         batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
    161         batch_x = next(db_iter)
    162 
    163         # 判别器前向计算
    164         with tf.GradientTape() as tape:
    165             d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
    166         grads = tape.gradient(d_loss, discriminator.trainable_variables)
    167         d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    168 
    169         with tf.GradientTape() as tape:
    170             g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
    171         grads = tape.gradient(g_loss, generator.trainable_variables)
    172         g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    173 
    174         if epoch % 100 == 0:
    175             print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss), 'gp:', float(gp))
    176             z = tf.random.uniform([100, z_dim])
    177 
    178             fake_image = generator(z, training=False)
    179             img_path = os.path.join('images', 'wgan-%d.png'%epoch)
    180             save_result(fake_image.numpy(), 10, img_path, color_mode='P')
    181 
    182         if epoch % 10000 == 1:
    183             # print(d_losses)
    184             # print(g_losses)
    185             generator.save_weights('exam11.2_generator.ckpt')
    186             discriminator.save_weights('exam11.2_discriminator.ckpt')
    187 
    188 
    189 if __name__ == '__main__':
    190     main()
    191     draw()

    同样没有得到理想的结果(这里同样只训练了 300 步),后面有条件再试一试。

    这一部分对算法原理的要求较高,要真正看懂需要花时间研究。

    我没有深入研究其原理,只是阅读了代码。

  • 相关阅读:
    Length of Last Word
    Remove Duplicates from Sorted Array II
    Sum Root to Leaf Numbers
    Valid Parentheses
    Set Matrix Zeroes
    Symmetric Tree
    Unique Binary Search Trees
    110Balanced Binary Tree
    Match:Blue Jeans(POJ 3080)
    Match:Seek the Name, Seek the Fame(POJ 2752)
  • 原文地址:https://www.cnblogs.com/heze/p/12390926.html
Copyright © 2011-2022 走看看