zoukankan      html  css  js  c++  java
  • GAN生成对抗网络-DCGAN原理与基本实现-深度卷积生成对抗网络03

    什么是DCGAN
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    实现代码

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np
    import glob
    import os
    
    # 显存自适应分配
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu,True)
    
    gpu_ok = tf.test.is_gpu_available()
    print("tf version:", tf.__version__)
    print("use GPU", gpu_ok) # 判断是否使用gpu进行训练
    
    # 手写数据集
    (train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()
    

    在这里插入图片描述
    在这里插入图片描述

    train_images = train_images.reshape(train_images.shape[0],28,28,1).astype("float32")
    

    在这里插入图片描述

    # 归一化
    train_images = (train_images-127.5)/127.5
    
    BATCH_SIZE = 256
    BUFFER_SIZE = 60000
    
    # 创建数据集
    datasets = tf.data.Dataset.from_tensor_slices(train_images)
    
    # 乱序
    datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    

    在这里插入图片描述
    编写模型

    # 生成器模型
    def generator_model():
        model = keras.Sequential() # 顺序模型
        model.add(layers.Dense(7*7*256,input_shape=(100,),use_bias=False)) # 输出7*7*256个单元,随机数输入数据形状长度100的向量
        model.add(layers.BatchNormalization()) # 批处理
        model.add(layers.LeakyReLU())# LeakyReLU()激活
        
        model.add(layers.Reshape((7,7,256))) # 7*7*256
        model.add(layers.Conv2DTranspose(128,(5,5),strides=(1,1),padding="same",use_bias=False)) # 反卷积
        model.add(layers.BatchNormalization()) # 批处理
        model.add(layers.LeakyReLU())# LeakyReLU()激活    # 7*7*128
        
        model.add(layers.Conv2DTranspose(64,(5,5),strides=(2,2),padding="same",use_bias=False)) # 反卷积
        model.add(layers.BatchNormalization()) # 批处理
        model.add(layers.LeakyReLU())# LeakyReLU()激活    # 14*14*64
        
        model.add(layers.Conv2DTranspose(1,(5,5),
                                         strides=(2,2),
                                         padding="same",
                                         use_bias=False,
                                         activation="tanh")) # 反卷积        # 28*28*1
    
        return model
    
    # 判别模型
    def discriminator_model():
        model = keras.Sequential()
        model.add(layers.Conv2D(64,(5,5),
                                strides=(2,2),
                                padding="same",
                                input_shape = (28,28,1)))
        
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.3))
        
        model.add(layers.Conv2D(128,(5,5),
                                strides=(2,2),
                                padding="same"))
        
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.3))
        
        model.add(layers.Conv2D(256,(5,5),
                                strides=(2,2),
                                padding="same"))
        
        model.add(layers.LeakyReLU())
        
        model.add(layers.Flatten())
        
        model.add(layers.Dense(1))
        
        
        return model
    
    # 编写loss    binary_crossentropy(对数损失函数)即 log loss,与 sigmoid 相对应的损失函数,针对于二分类问题。
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    
    # 辨别器loss
    def discriminator_loss(real_out,fake_out):
        read_loss = cross_entropy(tf.ones_like(real_out),real_out) # 使用binary_crossentropy 对真实图片判别为1
        fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out) # 生成的图片 判别为0
        return read_loss + fake_loss
    
    # 生成器loss
    def generator_loss(fake_out):
        return cross_entropy(tf.ones_like(fake_out),fake_out) # 希望对生成的图片返回为1
    
    # 优化器
    generator_opt = tf.keras.optimizers.Adam(1e-4) # 学习速率1e-4   0.0001
    discriminator_opt = tf.keras.optimizers.Adam(1e-4)
    
    EPOCHS = 100 # 训练步数
    noise_dim = 100 
    
    num_exp_to_generate = 16
    
    seed = tf.random.normal([num_exp_to_generate,noise_dim]) # 16,100 # 生成16个样本,长度为100的随机数
    
    generator = generator_model()
    
    discriminator = discriminator_model()
    
    # 训练一个epoch
    def train_step(images):
        noise = tf.random.normal([BATCH_SIZE,noise_dim])
        
        with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape: # 梯度
            real_out = discriminator(images,training=True)
            
            gen_image = generator(noise,training=True)
            
            fake_out = discriminator(gen_image,training=True)
            
            gen_loss = generator_loss(fake_out)
            disc_loss = discriminator_loss(real_out,fake_out)
            
        gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)
        gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables)
        generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
        discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))
    
    def genrate_plot_image(gen_model,test_noise):
        pre_images = gen_model(test_noise,training=False)
        fig = plt.figure(figsize=(4,4))
        for i in range(pre_images.shape[0]):
            plt.subplot(4,4,i+1)
            plt.imshow((pre_images[i,:,:,0]+1)/2,cmap="gray")
            plt.axis("off")
        plt.show()
    
    def train(dataset,epochs):
        for epoch in range(epochs):
            for image_batch in dataset:
                train_step(image_batch)
                print(".",end="")
            genrate_plot_image(generator,seed)
    
    # 训练模型
    train(datasets,EPOCHS)
    

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

  • 相关阅读:
    CodeForces gym Nasta Rabbara lct
    bzoj 4025 二分图 lct
    CodeForces 785E Anton and Permutation
    bzoj 3669 魔法森林
    模板汇总——快读 fread
    bzoj2049 Cave 洞穴勘测 lct
    bzoj 2002 弹飞绵羊 lct裸题
    HDU 6394 Tree 分块 || lct
    HDU 6364 Ringland
    nyoj221_Tree_subsequent_traversal
  • 原文地址:https://www.cnblogs.com/gemoumou/p/14186251.html
Copyright © 2011-2022 走看看