zoukankan      html  css  js  c++  java
  • 使用GAN生成手写数字样本

    from __future__ import print_function, division
    
    from keras.datasets import mnist
    from keras.layers import Input, Dense, Reshape, Flatten, Dropout
    from keras.layers import BatchNormalization, Activation, ZeroPadding2D
    from keras.layers.advanced_activations import LeakyReLU
    from keras.layers.convolutional import UpSampling2D, Conv2D
    from keras.models import Sequential, Model
    from keras.optimizers import Adam
    
    import matplotlib.pyplot as plt
    
    import sys
    import os
    import numpy as np
    
    class GAN():
        """Fully-connected GAN that generates 28x28x1 images from MNIST digits.

        The object holds three Keras models:

        * ``discriminator`` -- maps an image to a single sigmoid score.
        * ``generator``     -- maps a ``latent_dim`` noise vector to an image.
        * ``combined``      -- generator followed by the discriminator, used to
          update the generator.
        """

        def __init__(self):
            """Build and compile the discriminator, generator and stacked model."""
            # Image geometry: 28 rows, 28 columns, 1 channel (the MNIST shape).
            self.img_rows = 28
            self.img_cols = 28
            self.channels = 1
            self.img_shape = (self.img_rows, self.img_cols, self.channels)
            # Length of the random noise vector fed to the generator.
            self.latent_dim = 100
            # Adam optimizer, used to compile both models below.
            optimizer = Adam(0.0002, 0.5)

            # The discriminator is compiled on its own so it can be trained
            # directly on real and generated batches.
            self.discriminator = self.build_discriminator()
            self.discriminator.compile(loss='binary_crossentropy',
                optimizer=optimizer,
                metrics=['accuracy'])

            self.generator = self.build_generator()
            gan_input = Input(shape=(self.latent_dim,))
            img = self.generator(gan_input)
            # Do not train the discriminator while training the generator.
            # NOTE: this flag is set after the discriminator was compiled above
            # and before the stacked model is compiled below; keep that order.
            self.discriminator.trainable = False
            # Score the generated (fake) images.
            validity = self.discriminator(img)
            self.combined = Model(gan_input, validity)
            self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

        def build_generator(self):
            """Return a Model mapping a noise vector to an image.

            Three Dense blocks (256, 512, 1024 units), each followed by
            LeakyReLU(0.2) and BatchNormalization(momentum=0.8), then a tanh
            Dense layer reshaped to ``self.img_shape``.
            """
            model = Sequential()

            model.add(Dense(256, input_dim=self.latent_dim))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(1024))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            # One output unit per pixel; tanh keeps values in [-1, 1].
            model.add(Dense(np.prod(self.img_shape), activation='tanh'))
            model.add(Reshape(self.img_shape))

            noise = Input(shape=(self.latent_dim,))
            img = model(noise)

            return Model(noise, img)

        def build_discriminator(self):
            """Return a Model mapping an image to a single sigmoid score.

            The image is flattened and passed through Dense(512) and Dense(256),
            each followed by LeakyReLU(0.2), then a 1-unit sigmoid layer.
            """
            model = Sequential()
            # Takes one image as input.
            model.add(Flatten(input_shape=self.img_shape))
            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(256))
            model.add(LeakyReLU(alpha=0.2))
            # Real-vs-fake decision.
            model.add(Dense(1, activation='sigmoid'))

            img = Input(shape=self.img_shape)
            validity = model(img)

            return Model(img, validity)

        def train(self, epochs, batch_size=128, sample_interval=50):
            """Alternate discriminator and generator updates on MNIST.

            Args:
                epochs: number of training iterations. Each iteration uses one
                    randomly drawn batch, not a full pass over the dataset.
                batch_size: number of real and of generated images per step.
                sample_interval: a sample grid is saved whenever the iteration
                    index is a multiple of this value.

            Raises:
                ValueError: if ``batch_size`` or ``sample_interval`` is below 1.
            """
            # Fail fast, before the dataset is loaded: a zero interval would
            # otherwise raise ZeroDivisionError after the first training step,
            # and a non-positive batch size yields an empty batch.
            if batch_size < 1:
                raise ValueError(
                    "batch_size must be at least 1, got %r" % (batch_size,))
            if sample_interval < 1:
                raise ValueError(
                    "sample_interval must be at least 1, got %r" % (sample_interval,))

            # Load the training images only; labels and test split are unused.
            (X_train, _), (_, _) = mnist.load_data()

            # Rescale to [-1, 1] (matching the generator's tanh output) and
            # append a trailing channel axis.
            X_train = X_train / 127.5 - 1.
            X_train = np.expand_dims(X_train, axis=3)

            # Targets: 1 for real images, 0 for generated ones.
            valid = np.ones((batch_size, 1))
            fake = np.zeros((batch_size, 1))

            for epoch in range(epochs):

                # --------------------------- #
                #  Train the discriminator on a random batch of real
                #  images and an equally sized batch of generated ones.
                # --------------------------- #
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                gen_imgs = self.generator.predict(noise)

                d_loss_real = self.discriminator.train_on_batch(imgs, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                # Element-wise mean of the two results; index 0 is reported as
                # the loss and index 1 as the accuracy below.
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # --------------------------- #
                #  Train the generator: fresh noise is labelled "valid" so the
                #  update pushes the discriminator's score towards 1.
                # --------------------------- #
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                g_loss = self.combined.train_on_batch(noise, valid)
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

                if epoch % sample_interval == 0:
                    self.sample_images(epoch)

        def sample_images(self, epoch):
            """Save a 5x5 grid of generated images to ``images/<epoch>.png``.

            The ``images`` directory is created if it does not exist yet.
            """
            r, c = 5, 5
            noise = np.random.normal(0, 1, (r * c, self.latent_dim))
            gen_imgs = self.generator.predict(noise)
            # Map the generator's [-1, 1] output to [0, 1] for display.
            gen_imgs = 0.5 * gen_imgs + 0.5

            # Make sure the output directory exists so the method does not
            # depend on the caller having created it.
            if not os.path.isdir("images"):
                os.makedirs("images")

            fig, axs = plt.subplots(r, c)
            try:
                # axs.flat walks the grid row by row, the same order in which
                # the generated images are consumed.
                for cnt, ax in enumerate(axs.flat):
                    ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                    ax.axis('off')
                fig.savefig("images/%d.png" % epoch)
            finally:
                # Close this specific figure (even if saving failed) so that
                # figures do not accumulate across sampling calls.
                plt.close(fig)
    
    if __name__ == '__main__':
        # Script entry point: make sure the output directory for the sample
        # grids exists, then build the models and run a short training session.
        if not os.path.exists("./images"):
            os.makedirs("./images")
        gan = GAN()
        # The sampling interval must be smaller than the iteration count;
        # otherwise `epoch % sample_interval == 0` holds only at epoch 0 and
        # the only saved grid comes from the untrained generator.
        gan.train(epochs=100, batch_size=256, sample_interval=20)

  • 相关阅读:
    C#执行sql文件
    "C:\Program Files\Internet Explorer\iexplore.exe" -extoff 无加载项启动IE 浏览器打开时全屏模式
    sql server备份与还原 sql语句
    触摸屏\串口服务器\串口
    USB串行端口
    选取所有表单元素
    C# Bitmap类型与Byte[]类型相互转化
    本文介绍C# BitmapData
    C#多线程学习(六) 互斥对象
    C# List 复制克隆副本
  • 原文地址:https://www.cnblogs.com/ywqtro/p/14817134.html
Copyright © 2011-2022 走看看