zoukankan      html  css  js  c++  java
  • 2019-3-10——生成对抗网络GAN---生成mnist手写数字图像

      1 """
      2 生成对抗网络(GAN,Generative Adversarial Networks)的基本原理很简单:
      3 假设有两个网络,生成网络G和判别网络D。生成网络G接受一个随机的噪声z并生成图片,
      4 记为G(z);判别网络D的作用是判别一张图片x是否真实,对于输入x,D(x)是x为真实图片的概率。
      5 在训练过程中, 生成器努力让生成的图片更加真实从而使得判别器无法辨别图像的真假,
      6 而D的目标就是尽量把分辨出真实图片和生成网络G产出的图片,这个过程就类似于二人博弈,
      7 G和D构成了一个动态的“博弈过程”。随着时间的推移,生成器和判别器在不断地进行对抗,
      8 最终两个网络达到一个动态平衡:生成器生成的图像G(z)接近于真实图像分布,而判别器识别不出真假图像,
      9 即D(G(z))=0.5。最后,我们就可以得到一个生成网络G,用来生成图片。
     10 """
     11 import tensorflow as tf
     12 from matplotlib import pyplot as plt
     13 import os
     14 import numpy as np
     15 from tensorflow.examples.tutorials.mnist import input_data
     16 mnist=input_data.read_data_sets('/MNIST_data/',one_hot=True)
     17 batch_size=64
     18 units_size=128
     19 learning_rate=0.001
     20 epoch=300
     21 smooth=0.1
     22 """定义生成模型"""
     23 def generatorModel(noise_img,units_size,out_size,alpha=0.01):
     24     """生成器的目的是:对于生成的图片,G希望D打上标签1"""
     25     with tf.variable_scope('generator'):
     26         FC=tf.layers.dense(noise_img,units_size)
     27         relu=tf.nn.leaky_relu(FC,alpha)
     28         drop=tf.layers.dropout(relu,rate=0.2)
     29         logits=tf.layers.dense(drop,out_size)
     30         outputs=tf.tanh(logits)
     31         return logits,outputs
     32 
     33 """定义判别模型"""
     34 def discriminatorModel(images,unite_size,alpha=0.01,reuse=False):
     35     """
     36     判别器的目的是:
     37     1. 对于真实图片,D要为其打上标签1
     38     2. 对于生成图片,D要为其打上标签0
     39     """
     40     with tf.variable_scope('discriminator',reuse=reuse):
     41         FC=tf.layers.dense(images,units_size)
     42         relu=tf.nn.leaky_relu(FC,alpha)
     43         logits=tf.layers.dense(relu,1)
     44         outputs=tf.sigmoid(logits)
     45         return logits,outputs
     46 """定义损失函数"""
     47 def loss_fenction(real_logits,fake_logits,smooth):
     48     """生成器希望判别器判别出来的标签为1; tf.ones_like()创建一个将所有元素都设置为1的张量"""
     49     G_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
     50         logits=fake_logits,
     51         labels=tf.ones_like(fake_logits)*(1-smooth))
     52     )
     53     """判别器识别生成器产出的图片,希望识别出来的标签为0;tf.zeros_like()创建一个将所有元素都设置为0的张量"""
     54     fake_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
     55         logits=fake_logits,
     56         labels=tf.zeros_like(fake_logits))
     57     )
     58     """判别器判别真实图片,希望判别出来的标签为1;tf.ones_like()创建一个将所有元素都设置为1的张量"""
     59     real_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
     60         logits=real_logits,
     61         labels=tf.ones_like(real_logits)*(1-smooth))
     62     )
     63     D_loss=tf.add(fake_loss,real_loss)
     64     return G_loss,fake_loss,real_loss,D_loss
     65 """定义优化器"""
     66 def optimizer(G_loss,D_loss,learning_rate):
     67     """因为GAN中一共训练了两个网络,所以分别对G和D进行优化"""
     68     train_var=tf.trainable_variables()   #需要训练的变量
     69     G_var=[var for var in train_var if var.name.startswith('generator')]
     70     D_var=[var for var in train_var if var.name.startswith('discriminator')]
     71     G_optimizer=tf.train.AdadeltaOptimizer(learning_rate).minimize(G_loss,var_list=G_var)
     72     D_optimizer=tf.train.AdadeltaOptimizer(learning_rate).minimize(D_loss,var_list=D_var)
     73     return G_optimizer,D_optimizer
     74 """训练"""
     75 def train(mnist):
     76     image_size = mnist.train.images[0].shape[0]
     77     real_images = tf.placeholder(tf.float32,[None,image_size])
     78     fake_images = tf.placeholder(tf.float32,[None,image_size])
     79     """调用生成模型生成假图片G_output"""
     80     G_logits,G_output = generatorModel(fake_images,units_size,image_size)
     81     """D对真实图像的判别"""
     82     real_logits,real_output = discriminatorModel(real_images,units_size)
     83     """D对G生成图像的判别"""
     84     fake_logits,fake_output=discriminatorModel(G_output,units_size,reuse=True)
     85     G_loss,real_loss,fake_loss,D_loss=loss_fenction(real_logits,fake_logits,smooth)
     86     G_optimizer,D_optimizer=optimizer(G_loss,D_loss,learning_rate)
     87 
     88     saver=tf.train.Saver()
     89     step=0
     90     with tf.Session() as session:
     91         session.run(tf.global_variables_initializer())
     92         for Epoch in range(epoch):
     93             for batch_i in range(mnist.train.num_examples//batch_size):
     94                 batch_image,_=mnist.train.next_batch(batch_size)
     95                 """对图像像素进行scale,tanh的输出结果为(-1,1)"""
     96                 batch_image=batch_image*2-1
     97                 """模型的输入噪声"""
     98                 noise_image=np.random.uniform(-1,1,size=(batch_size,image_size))#从均匀分布[-1,1)中随机采样
     99                 session.run(G_optimizer,feed_dict={fake_images:noise_image})
    100                 session.run(D_optimizer,feed_dict={real_images:batch_image,fake_images:noise_image})
    101                 step=step+1
    102                 loss_D= session.run(D_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
    103                 loss_real= session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
    104                 loss_fake= session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
    105                 loss_G= session.run(G_loss, feed_dict={fake_images: noise_image})
    106             print('epoch:', Epoch, 'loss_D:', loss_D,'loss_real', loss_real,'loss_fake', loss_fake, 'loss_G',  loss_G)
    107             model_path=os.getcwd()+os.sep+"mnist.model"
    108             saver.save(session,model_path,global_step=step)
    109 """定义主函数"""
    110 def main(argv=None):
    111     train(mnist)
    112 if __name__ =='__main__':
    113     tf.app.run()


     1 import tensorflow as tf
     2 import numpy as np
     3 from matplotlib import pyplot as plt
     4 import pickle
     5 import example88_0
     6 
     7 UNITS_SIZE = example88_0.units_size
     8 
     9 
    10 def generatorImage(image_size):
    11     sample_images = tf.placeholder(tf.float32, [None, image_size])
    12     G_logits, G_output = example88_0.generatorModel(sample_images, UNITS_SIZE, image_size)
    13     saver = tf.train.Saver()
    14     with tf.Session() as session:
    15         session.run(tf.global_variables_initializer())
    16         saver.restore(session, tf.train.latest_checkpoint('.'))
    17         sample_noise = np.random.uniform(-1, 1, size=(25, image_size))
    18         samples = session.run(G_output, feed_dict={sample_images: sample_noise})
    19     with open('samples.pkl', 'wb') as f:
    20         pickle.dump(samples, f)
    21 
    22 
    23 def show():
    24     with open('samples.pkl', 'rb') as f:
    25         samples = pickle.load(f)
    26     fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
    27     for ax, image in zip(axes.flatten(), samples):
    28         ax.xaxis.set_visible(False)
    29         ax.yaxis.set_visible(False)
    30         ax.imshow(image.reshape((28, 28)), cmap='Greys_r')
    31     plt.show()
    32 
    33 
    34 def main(argv=None):
    35     image_size = example88_0.mnist.train.images[0].shape[0]
    36     generatorImage(image_size)
    37     show()
    38 
    39 
    40 if __name__ == '__main__':
    41     tf.app.run()

    萍水相逢逢萍水,浮萍之水水浮萍!
  • 相关阅读:
    MongoCola使用教程 1 MongoDB的基本操作和聚合功能
    [教程]MongoDB 从入门到进阶 (TextSearch)
    MongoCola使用教程 2 MongoDB的Replset 初始化和配置
    [教程]MongoDB 从入门到进阶 (aggregation数据库状态)
    [教程]MongoDB 从入门到进阶 (概要 以及 高级索引篇 TimeToLive GeoNear)
    C#多线程函数如何传参数和返回值
    QQ邮箱 C# 发邮件 常见错误异常
    关于c#中的Timer控件的简单用法
    要想使用线程 想去方法 应该传入object 传参
    quartz给任务传参数以及维持任务的状态
  • 原文地址:https://www.cnblogs.com/AIBigTruth/p/10505724.html
Copyright © 2011-2022 走看看