zoukankan      html  css  js  c++  java
  • 使用GAN生成手写数字样本

    from __future__ import print_function, division
    
    from keras.datasets import mnist
    from keras.layers import Input, Dense, Reshape, Flatten, Dropout
    from keras.layers import BatchNormalization, Activation, ZeroPadding2D
    from keras.layers.advanced_activations import LeakyReLU
    from keras.layers.convolutional import UpSampling2D, Conv2D
    from keras.models import Sequential, Model
    from keras.optimizers import Adam
    
    import matplotlib.pyplot as plt
    
    import sys
    import os
    import numpy as np
    
    class GAN():
        """Fully connected GAN that learns to generate 28x28x1 MNIST-style digits.

        ``__init__`` builds three Keras models:

        * ``self.discriminator`` -- maps an image to a sigmoid "real" score and
          is compiled for stand-alone training.
        * ``self.generator`` -- maps a ``latent_dim`` noise vector to an image.
        * ``self.combined`` -- generator followed by the frozen discriminator;
          training it updates the generator only.
        """

        def __init__(self):
            # Image geometry: 28 rows x 28 columns x 1 channel (the MNIST shape).
            self.img_rows = 28
            self.img_cols = 28
            self.channels = 1
            self.img_shape = (self.img_rows, self.img_cols, self.channels)
            # Length of the random noise vector fed to the generator.
            self.latent_dim = 100

            # Each compiled model gets its own Adam instance (learning rate
            # 0.0002, beta_1 0.5) so that optimizer state is not shared between
            # the discriminator and the stacked generator model.
            self.discriminator = self.build_discriminator()
            self.discriminator.compile(loss='binary_crossentropy',
                optimizer=Adam(0.0002, 0.5),
                metrics=['accuracy'])

            self.generator = self.build_generator()
            gan_input = Input(shape=(self.latent_dim,))
            img = self.generator(gan_input)
            # The discriminator must not be trained while the generator is
            # being trained through the combined model.
            self.discriminator.trainable = False
            # Score the generated (fake) images.
            validity = self.discriminator(img)
            self.combined = Model(gan_input, validity)
            self.combined.compile(loss='binary_crossentropy',
                optimizer=Adam(0.0002, 0.5))

        def build_generator(self):
            """Build the generator: noise vector -> image of shape ``img_shape``.

            Three Dense blocks (256, 512, 1024 units), each followed by
            LeakyReLU(0.2) and BatchNormalization(momentum=0.8), then a tanh
            Dense layer reshaped to ``self.img_shape``.

            Returns:
                A Keras ``Model`` taking a ``(latent_dim,)`` input.
            """
            model = Sequential()

            model.add(Dense(256, input_dim=self.latent_dim))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(1024))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            # np.prod returns a NumPy integer; pass a plain int as the unit count.
            model.add(Dense(int(np.prod(self.img_shape)), activation='tanh'))
            model.add(Reshape(self.img_shape))

            noise = Input(shape=(self.latent_dim,))
            img = model(noise)

            return Model(noise, img)

        def build_discriminator(self):
            """Build the discriminator: image -> single sigmoid validity score.

            The image is flattened and passed through Dense(512) and Dense(256),
            each followed by LeakyReLU(0.2), then a 1-unit sigmoid output.

            Returns:
                A Keras ``Model`` taking an input of shape ``self.img_shape``.
            """
            model = Sequential()
            # Takes one image as input.
            model.add(Flatten(input_shape=self.img_shape))
            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(256))
            model.add(LeakyReLU(alpha=0.2))
            # Real-vs-fake decision.
            model.add(Dense(1, activation='sigmoid'))

            img = Input(shape=self.img_shape)
            validity = model(img)

            return Model(img, validity)

        def train(self, epochs, batch_size=128, sample_interval=50):
            """Alternate discriminator and generator updates on MNIST.

            Note that one "epoch" here is a single iteration over one randomly
            drawn batch, not a full pass over the dataset.

            Args:
                epochs: Number of training iterations to run.
                batch_size: Number of real (and of generated) images per
                    iteration.
                sample_interval: A sample grid is saved whenever
                    ``epoch % sample_interval == 0``.
            """
            # Only the training images are used; labels and the test split are
            # discarded.
            (X_train, _), (_, _) = mnist.load_data()

            # x / 127.5 - 1 maps 0..255 onto [-1, 1], the range of the
            # generator's tanh output; then a trailing channel axis is added.
            X_train = X_train / 127.5 - 1.
            X_train = np.expand_dims(X_train, axis=3)

            # Target labels: 1 for real images, 0 for generated ones.
            valid = np.ones((batch_size, 1))
            fake = np.zeros((batch_size, 1))

            for epoch in range(epochs):

                # ---------------------------------------------- #
                #   Train the discriminator on batch_size random
                #   real images and batch_size generated images.
                # ---------------------------------------------- #
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                gen_imgs = self.generator.predict(noise)

                d_loss_real = self.discriminator.train_on_batch(imgs, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                # Element-wise mean of [loss, accuracy] over the two batches.
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # ---------------------------------------------- #
                #   Train the generator: fresh noise is labelled
                #   "valid" so the generator is pushed to fool
                #   the (frozen) discriminator.
                # ---------------------------------------------- #
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                g_loss = self.combined.train_on_batch(noise, valid)
                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

                if epoch % sample_interval == 0:
                    self.sample_images(epoch)

        def sample_images(self, epoch):
            """Save a 5x5 grid of generated digits to ``images/<epoch>.png``.

            The ``images`` directory is created if it does not exist yet, so the
            method also works when the class is used outside the script entry
            point.

            Args:
                epoch: Iteration number, used as the output file name.
            """
            r, c = 5, 5
            noise = np.random.normal(0, 1, (r * c, self.latent_dim))
            gen_imgs = self.generator.predict(noise)
            # Rescale from the tanh range [-1, 1] to [0, 1] for display.
            gen_imgs = 0.5 * gen_imgs + 0.5

            fig, axs = plt.subplots(r, c)
            cnt = 0
            for i in range(r):
                for j in range(c):
                    axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                    axs[i,j].axis('off')
                    cnt += 1

            # Make sure the target directory exists before writing the file.
            if not os.path.isdir("images"):
                os.makedirs("images")
            fig.savefig("images/%d.png" % epoch)
            # Close this specific figure rather than whichever one is current.
            plt.close(fig)
    
    if __name__ == '__main__':
        # Script entry point: prepare the folder that receives the sample
        # grids, then build the GAN and run a short training session.
        output_dir = "./images"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        GAN().train(epochs=100, batch_size=256, sample_interval=200)

  • 相关阅读:
    从PowerDesigner概念设计模型(CDM)中的3种实体关系说起
    详细整数分区的问题解释
    JavaWeb显示器
    mysql中国的内容php网页乱码问题
    Linux开发环境的搭建和使用——Linux本必备软件SSH
    【Android UI设计和开发】动画(Animation)详细说明(一)
    应该是什么在预新手发“外链”(4)最终的外链的方法
    选择29部分有用jQuery应用程序插件(免费点数下载)
    设计模式--装饰图案
    Webserver管理系列:3、Windows Update
  • 原文地址:https://www.cnblogs.com/ywqtro/p/14817134.html
Copyright © 2011-2022 走看看