zoukankan      html  css  js  c++  java
  • GAN-生成式对抗网络(keras实现)

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是最近超级火的一个无监督学习方法,它主要由两部分组成,一部分是生成模型G(generator),另一部分是判别模型D(discriminator),它的训练过程可大致描述如下:

    生成模型通过接收一个随机噪声来生成图片,判别模型用来判断这个图片是不是“真实的”,也就是说,生成网络的目标是尽量生成真实的图片去欺骗判别网络,判别网络的目标就是把G生成的图片和真实的图片区分开来,从而构成一个动态的博弈过程。

    GAN主要用来解决的问题是:在数据量不足的情况下,通过小型数据集去生成一些数据

    从理论上来说,GAN系列神经网络可以用来模拟任何数据分布,但是目前更主要用于图像。

    而事实也证明,GAN生成的数据是可以直接用在实际的图像问题上的,如行人重识别数据集,细粒度识别等。

     (GAN的网络结构及训练流程)

    下面是用keras实现的GAN:

      1 from __future__ import print_function, division
      2 
      3 from keras.datasets import mnist
      4 from keras.layers import Input, Dense, Reshape, Flatten, Dropout
      5 from keras.layers import BatchNormalization, Activation, ZeroPadding2D
      6 from keras.layers.advanced_activations import LeakyReLU
      7 from keras.layers.convolutional import UpSampling2D, Conv2D
      8 from keras.models import Sequential, Model
      9 from keras.optimizers import Adam
     10 
     11 import matplotlib.pyplot as plt
     12 
     13 import sys
     14 
     15 import numpy as np
     16 
     17 class GAN():
     18     def __init__(self):
     19         # 定义输入图像尺寸及通道
     20         self.img_rows = 28
     21         self.img_cols = 28
     22         self.channels = 1
     23         self.img_shape = (self.img_rows, self.img_cols, self.channels)
     24         self.latent_dim = 100
     25 
     26         # 设置网络优化器
     27         optimizer = Adam(0.0002, 0.5)
     28 
     29         # 构建判别网络
     30         self.discriminator = self.build_discriminator()
     31         self.discriminator.compile(loss='binary_crossentropy',
     32             optimizer=optimizer,
     33             metrics=['accuracy'])
     34 
     35         # 构建生成网络
     36         self.generator = self.build_generator()
     37 
     38         # 生成器根据噪声生成图像
     39         z = Input(shape=(self.latent_dim,))
     40         img = self.generator(z)
     41 
     42         # 在联合模型中,设置判别器参数不可训练
     43         self.discriminator.trainable = False
     44 
     45         # 判别器验证生成图像
     46         validity = self.discriminator(img)
     47 
     48         # 训练生成器来欺骗判别器
     49         self.combined = Model(z, validity)
     50         self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
     51 
     52 
     53     # 生成器结构
     54     def build_generator(self):
     55 
     56         model = Sequential()
     57 
     58         model.add(Dense(256, input_dim=self.latent_dim))
     59         model.add(LeakyReLU(alpha=0.2))
     60         model.add(BatchNormalization(momentum=0.8))
     61         model.add(Dense(512))
     62         model.add(LeakyReLU(alpha=0.2))
     63         model.add(BatchNormalization(momentum=0.8))
     64         model.add(Dense(1024))
     65         model.add(LeakyReLU(alpha=0.2))
     66         model.add(BatchNormalization(momentum=0.8))
     67         model.add(Dense(np.prod(self.img_shape), activation='tanh'))
     68         model.add(Reshape(self.img_shape))
     69 
     70         model.summary()
     71 
     72         noise = Input(shape=(self.latent_dim,))
     73         img = model(noise)
     74 
     75         return Model(noise, img)
     76 
     77     # 判别器结构
     78     def build_discriminator(self):
     79 
     80         model = Sequential()
     81 
     82         model.add(Flatten(input_shape=self.img_shape))
     83         model.add(Dense(512))
     84         model.add(LeakyReLU(alpha=0.2))
     85         model.add(Dense(256))
     86         model.add(LeakyReLU(alpha=0.2))
     87         model.add(Dense(1, activation='sigmoid'))
     88         model.summary()
     89 
     90         img = Input(shape=self.img_shape)
     91         validity = model(img)
     92 
     93         return Model(img, validity)
     94 
     95     # 定义训练过程
     96     def train(self, epochs, batch_size=128, sample_interval=50):
     97         (X_train, _), (_, _) = mnist.load_data()
     98 
     99         X_train = X_train / 127.5 - 1.
    100         X_train = np.expand_dims(X_train, axis=3)
    101 
    102         valid = np.ones((batch_size, 1))
    103         fake = np.zeros((batch_size, 1))
    104 
    105         for epoch in range(epochs):
    106 
    107             idx = np.random.randint(0, X_train.shape[0], batch_size)
    108             imgs = X_train[idx]
    109 
    110             noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
    111 
    112             gen_imgs = self.generator.predict(noise)
    113 
    114             d_loss_real = self.discriminator.train_on_batch(imgs, valid)
    115             d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
    116             d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    117 
    118             noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
    119 
    120             # 根据判别器valid训练生成器
    121             g_loss = self.combined.train_on_batch(noise, valid)
    122 
    123             print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
    124 
    125             # 保存生成图像
    126             if epoch % sample_interval == 0:
    127                 self.sample_images(epoch)
    128 
    129     def sample_images(self, epoch):
    130         r, c = 5, 5
    131         noise = np.random.normal(0, 1, (r * c, self.latent_dim))
    132         gen_imgs = self.generator.predict(noise)
    133 
    134         gen_imgs = 0.5 * gen_imgs + 0.5
    135 
    136         fig, axs = plt.subplots(r, c)
    137         cnt = 0
    138         for i in range(r):
    139             for j in range(c):
    140                 axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
    141                 axs[i,j].axis('off')
    142                 cnt += 1
    143         fig.savefig("images/%d.png" % epoch)
    144         plt.close()
    145 
    146 
    147 if __name__ == '__main__':
    148     gan = GAN()
    149     gan.train(epochs=30000, batch_size=32, sample_interval=200)

    程序初始运行结果如下:

     训练完成后效果如下:

  • 相关阅读:
    asp.net mvc 中使用async/await异步编程
    简述C#中浅复制和深复制
    Angular:自定义表单控件
    Angular:Reactive Form的使用方法和自定义验证器
    Angular:ViewProviders和Providers的区别
    Angular:OnPush变化检测策略介绍
    Angular:利用内容投射向组件输入ngForOf模板
    在Angular中利用trackBy来提升性能
    Angular @HostBinding()和@HostListener()用法
    Angular利用@ViewChild在父组件执行子组件的方法
  • 原文地址:https://www.cnblogs.com/zdm-code/p/13856775.html
Copyright © 2011-2022 走看看