zoukankan      html  css  js  c++  java
  • GAN模型生成手写字

    概述:在前期的文章中,我们用TensorFlow完成了对手写数字的识别,得到了94.09%的识别准确度,效果还算不错。在这篇文章中,笔者将带领大家用GAN模型,生成我们想要的手写数字。

    GAN简介

    对抗性生成网络(GenerativeAdversarial Network),由 Ian Goodfellow 首先提出,由两个网络组成,分别是generator网络(用于生成)和discriminator网络(用于判别)。GAN网络的目的就是使其自己生成一副图片,比如说经过对一系列猫的图片的学习,generator网络可以自己“绘制”出一张猫的图片,且尽量真实。discriminator网络则是用来进行判断的,将一张真实的图片和一张由generator网络生成的照片同时交给discriminator网络,不断训练discriminator网络,使其可以准确将discriminator网络生成的“假图片”找出来。就这样,generator网络不断改进使其可以骗过discriminator网络,而discriminator网络不断改进使其可以更准确找到“假图片”,这种相互促进相互对抗的关系,就叫做对抗网络。图一中展示了GAN模型的结构。

    思路梳理

    将MNIST数据集中标签为0的图片提取出来,然后训练discriminator网络,进行手写数字0识别,接着让generator产生一张随机图片,让训练好的discriminator去识别这张生成的图片,不断训练discriminator,直到discriminator网络将生成的图片当做数字0为止。

    生成“假图片

    生成一张随机像素的28*28的图片,分别进行全连接,Leaky ReLU函数激活,dropout处理(随机丢弃一些神经元,防止过拟合),全连接,tanh函数激活,最终生成一张“假图片”,TensorFlow代码如下:

    1def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    2    with tf.variable_scope("generator", reuse=reuse):
    3        hidden1 = tf.layers.dense(noise_img, n_units)  # 全连接层
    4        hidden1 = tf.maximum(alpha * hidden1, hidden1)
    5        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
    6        logits = tf.layers.dense(hidden1, out_dim)
    7        outputs = tf.tanh(logits)
    8        return logits, outputs

    图像判别

    将需要进行判别的图片先后经过全连接,Leaky ReLU函数激活,全连接,sigmoid函数激活处理,最终输出图片的识别结果,TensorFlow代码如下:

    1def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    2    with tf.variable_scope("discriminator", reuse=reuse):
    3        hidden1 = tf.layers.dense(img, n_units)
    4        hidden1 = tf.maximum(alpha * hidden1, hidden1)
    5        logits = tf.layers.dense(hidden1, 1)
    6        outputs = tf.sigmoid(logits)
    7        return logits, outputs

    完整代码

    GAN手写数字识别的完整代码如下:

      1import tensorflow as tf
     2from tensorflow.examples.tutorials.mnist import input_data
     3import matplotlib.pyplot as plt
     4import numpy as np
     5
     6mnist = input_data.read_data_sets("E:/Tensor/MNIST_data/")
     7img = mnist.train.images[50]
     8
     9
    10def get_inputs(real_size, noise_size):
    11    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
    12    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
    13    return real_img, noise_img
    14
    15
    16# 生成图像
    17def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    18    with tf.variable_scope("generator", reuse=reuse):
    19        hidden1 = tf.layers.dense(noise_img, n_units)  # 全连接层
    20        hidden1 = tf.maximum(alpha * hidden1, hidden1)
    21        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
    22        logits = tf.layers.dense(hidden1, out_dim)
    23        outputs = tf.tanh(logits)
    24        return logits, outputs
    25
    26
    27# 图像判别
    28def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    29    with tf.variable_scope("discriminator", reuse=reuse):
    30        hidden1 = tf.layers.dense(img, n_units)
    31        hidden1 = tf.maximum(alpha * hidden1, hidden1)
    32        logits = tf.layers.dense(hidden1, 1)
    33        outputs = tf.sigmoid(logits)
    34        return logits, outputs
    35#真实图像size
    36img_size = mnist.train.images[0].shape[0]
    37#传入generator的噪声size
    38noise_size = 100
    39#生成器隐层参数
    40g_units = 128
    41#判别器隐层参数
    42d_units = 128
    43#Leaky ReLU参数
    44alpha = 0.01
    45#学习率
    46learning_rate = 0.001
    47#label smoothing
    48smooth = 0.1
    49tf.reset_default_graph()
    50real_img, noise_img = get_inputs(img_size, noise_size)
    51g_logits, g_outputs = get_generator(noise_img, g_units, img_size)
    52
    53d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
    54d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)
    55
    56d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    57    logits=d_logits_real, labels=tf.ones_like(d_logits_real)
    58) * (1 - smooth))
    59d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    60    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
    61))
    62d_loss = tf.add(d_loss_real, d_loss_fake)
    63g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    64    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)
    65) * (1 - smooth))
    66
    67train_vars = tf.trainable_variables()
    68g_vars = [var for var in train_vars if var.name.startswith("generator")]
    69d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
    70
    71d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
    72g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
    73
    74
    75epochs = 10000
    76samples = []
    77n_sample = 10
    78losses = []
    79
    80i = j = 0
    81while i<10000:
    82    if mnist.train.labels[j] == 0:
    83        samples.append(mnist.train.images[j])
    84        i += 1
    85    j += 1
    86
    87print(len(samples))
    88size = samples[0].size
    89
    90with tf.Session() as sess:
    91    tf.global_variables_initializer().run()
    92    for e in range(epochs):
    93        batch_images = samples[e] * -1
    94        batch_noise = np.random.uniform(-1, 1, size=noise_size)
    95
    96        _ = sess.run(d_train_opt, feed_dict={real_img:[batch_images], noise_img:[batch_noise]})
    97        _ = sess.run(g_train_opt, feed_dict={noise_img:[batch_noise]})
    98
    99    sample_noise = np.random.uniform(-1, 1, size=noise_size)
    100    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size,
    101                                         reuse=True), feed_dict={
    102        noise_img:[sample_noise]
    103    })
    104    print(g_logit.size)
    105    g_output = (g_output+1)/2
    106    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
    107    plt.show()

    训练效果

    在经过了10000次的迭代后,generator网络生成的图片已经接近手写数字零的形状。

      

      本文是对GAN模型的初次探索,在后续GAN模型的系列文章中,笔者将层层深入的去讲解GAN模型复杂的应用。

  • 相关阅读:
    何时使用Hibernate (Gavin King的回答)
    Transaction in ADO.net 2.0
    CollectionClosureMethod in .Net
    如何实现真正的随机数
    如何测试私有方法?(TDD)
    try catch 块的使用原则
    多态小quiz
    A simple way to roll back DB pollution in Test
    一个画图程序的演变
    当前软件开发的反思
  • 原文地址:https://www.cnblogs.com/followees/p/10422792.html
Copyright © 2011-2022 走看看