zoukankan      html  css  js  c++  java
  • TensorFlow 2.0 —— GAN实战代码

    from  tensorflow import keras
    import tensorflow as tf
    from  tensorflow.keras import layers
    import numpy as np
    import os
    import matplotlib.pyplot as plt
    
    # Low-level runtime configuration.
    # To pin training to a specific GPU, uncomment the two lines below. They must
    # run before TensorFlow enumerates the devices (the list_physical_devices call
    # that follows); setting them after enumeration has no effect.
    # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the 2nd GPU
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    if not physical_devices:
        # Explicit check instead of `assert`: asserts are stripped under `python -O`,
        # which would turn this into an opaque IndexError further down.
        raise RuntimeError("Not enough GPU hardware devices available")
    # Allocate GPU memory on demand instead of reserving it all up front.
    # Applied to every visible GPU: TensorFlow does not allow the memory-growth
    # setting to differ between GPU devices.
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
    
    #   拼接图片
    def my_save_img(data, save_path, r_c=10):
        """Tile images into an ``r_c`` x ``r_c`` grid and save it as one picture.

        Args:
            data: array-like of images (a NumPy array or an eager ``tf.Tensor``)
                shaped ``[n, h, w, c]`` with ``c`` equal to 3 (or 1, which is
                broadcast to RGB) and values in ``[-1, 1]``, the range of the
                generator's ``tanh`` output.
            save_path: destination image path; a missing parent directory is
                created.
            r_c: number of grid rows and columns (default 10). Only the first
                ``r_c * r_c`` images are drawn; unused cells stay black.
        """
        # np.asarray also accepts eager tf.Tensors, so indexing below is plain NumPy.
        images = np.asarray(data)[:r_c * r_c]
        # Tile size is taken from the data instead of being hard-coded to 64.
        img_h, img_w = images.shape[1:3]
        new_img = np.zeros((r_c * img_h, r_c * img_w, 3))
        for index, each_img in enumerate(images):
            row, col = divmod(index, r_c)
            # Map [-1, 1] back to [0, 1]. Clip so out-of-range values cannot reach
            # matplotlib, which expects float RGB data in [0, 1].
            new_img[row * img_h:(row + 1) * img_h,
                    col * img_w:(col + 1) * img_w, :] = np.clip((each_img + 1) / 2, 0.0, 1.0)

        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.imsave(save_path, new_img)
    
    class Generator(keras.Model):
        """Map a latent vector to an RGB image in [-1, 1] (tanh output).

        Shape flow (transposed convolutions, 'valid' padding):
        [b, 100] -> Dense -> [b, 3*3*512] -> reshape -> [b, 3, 3, 512]
        -> (k3, s3) -> [b, 9, 9, 256] -> (k5, s2) -> [b, 21, 21, 128]
        -> (k4, s3) -> [b, 64, 64, 3].
        """

        def __init__(self):
            super().__init__()
            self.fc = layers.Dense(3 * 3 * 512)
            self.Tconv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
            self.bn1 = layers.BatchNormalization()
            self.Tconv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
            self.bn2 = layers.BatchNormalization()
            self.Tconv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

        def call(self, inputs, training=None, mask=None):
            # Project the latent vector to a 3x3x512 seed feature map.
            hidden = tf.nn.leaky_relu(tf.reshape(self.fc(inputs), [-1, 3, 3, 512]))
            # Two upsampling stages, each: transposed conv -> batch norm -> leaky ReLU.
            for deconv, norm in ((self.Tconv1, self.bn1), (self.Tconv2, self.bn2)):
                hidden = tf.nn.leaky_relu(norm(deconv(hidden), training=training))
            # Final upsampling to 3 channels, squashed into [-1, 1].
            return tf.tanh(self.Tconv3(hidden))
    
    class Discriminator(keras.Model):
        """Convolutional discriminator returning one unnormalized logit per image.

        For a [b, 64, 64, 3] input (all convolutions: kernel 5, stride 3, 'valid'):
        [b, 64, 64, 3] -> [b, 20, 20, 64] -> [b, 6, 6, 128] -> [b, 1, 1, 256]
        -> flatten -> [b, 256] -> Dense -> [b, 1].
        """

        def __init__(self):
            super().__init__()
            self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
            self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
            self.bn2 = layers.BatchNormalization()
            self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
            self.bn3 = layers.BatchNormalization()
            self.flatten = layers.Flatten()
            self.fc = layers.Dense(1)

        def call(self, inputs, training=None, mask=None):
            # The first conv block has no batch normalization.
            features = tf.nn.leaky_relu(self.conv1(inputs))
            # Remaining blocks: conv -> batch norm -> leaky ReLU.
            for conv, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
                features = tf.nn.leaky_relu(norm(conv(features), training=training))
            # Dense(1) with no activation: the result is a raw logit, not a probability.
            return self.fc(self.flatten(features))
    
    def _discriminator_loss(d_real_logits, d_fake_logits):
        """Return the discriminator's mean binary cross-entropy loss.

        Real images are labelled 1 and generated images 0. Each term is averaged
        over its own batch, so the real and fake batches may differ in size (the
        last batch of a pass over the dataset can be smaller than the noise batch).
        """
        d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_real_logits, labels=tf.ones_like(d_real_logits)))
        d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits)))
        return d_loss_real + d_loss_fake


    def _generator_loss(d_fake_logits):
        """Return the generator's mean loss: generated images should be classified as real (label 1)."""
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))


    def main():
        """Train the GAN on images loaded from ``img.npy`` and periodically save samples.

        Every 10 steps the losses are printed; every 50 steps a grid of generated
        images is written to ``g_pic2/gan-<step>.png``.
        """
        # Hyperparameters
        z_dim = 100
        epochs = 3000000  # number of training steps (one batch each), not full passes over the data
        batch_size = 1024
        learning_rate = 0.002
        is_training = True
        sample_dir = 'g_pic2'
        num_samples = 100  # my_save_img draws a 10x10 grid by default; more samples would be discarded

        # NOTE(review): the generator ends in tanh, so img.npy is presumably scaled
        # to [-1, 1]; the print below shows the actual value range.
        img_data = np.load('img.npy')
        train_db = tf.data.Dataset.from_tensor_slices(img_data).shuffle(10000).batch(batch_size)
        sample = next(iter(train_db))
        print(sample.shape, tf.reduce_max(sample).numpy(),
              tf.reduce_min(sample).numpy())

        train_db = train_db.repeat()
        db_iter = iter(train_db)

        d = Discriminator()
        g = Generator()

        # Separate optimizers for the two networks.
        g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
        d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

        # plt.imsave does not create missing directories.
        os.makedirs(sample_dir, exist_ok=True)

        for epoch in range(epochs):
            batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
            batch_x = next(db_iter)

            # Train D: push real images towards label 1 and generated images towards 0.
            with tf.GradientTape() as tape:
                fake_image = g(batch_z, training=is_training)
                d_fake_logits = d(fake_image, training=is_training)
                d_real_logits = d(batch_x, training=is_training)
                d_loss = _discriminator_loss(d_real_logits, d_fake_logits)
            grads = tape.gradient(d_loss, d.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, d.trainable_variables))

            # Train G: make D classify generated images as real.
            with tf.GradientTape() as tape:
                fake_image = g(batch_z, training=is_training)
                d_fake_logits = d(fake_image, training=is_training)
                g_loss = _generator_loss(d_fake_logits)
            grads = tape.gradient(g_loss, g.trainable_variables)
            g_optimizer.apply_gradients(zip(grads, g.trainable_variables))

            if epoch % 10 == 0:
                print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
                if epoch % 50 == 0:
                    # Sample from the same U(-1, 1) prior used during training.
                    z = tf.random.uniform([num_samples, z_dim], minval=-1., maxval=1.)
                    fake_image = g(z, training=False)
                    img_path = os.path.join(sample_dir, 'gan-%d.png' % epoch)
                    my_save_img(fake_image, img_path)


    if __name__ == '__main__':
        main()
  • 相关阅读:
    206.反转链表
    gprof
    Java【Stream流、方法引用】学习笔记
    Java【函数式接口(Supplier、Consumer、Predicate、Function)】学习笔记
    Python exec 内置语句
    Python sorted() 函数
    Python oct() 函数
    Python id() 函数
    Python dir() 函数
    软件测试的方法
  • 原文地址:https://www.cnblogs.com/cxhzy/p/14268107.html
Copyright © 2011-2022 走看看