zoukankan      html  css  js  c++  java
  • GAN生成式对抗网络(二)——tensorflow代码示例

    代码实现

    当初学习时,主要学习的这个博客 https://xyang35.github.io/2017/08/22/GAN-1/ ,写的挺好的。

    本文目的,用GAN实现最简单的例子,帮助认识GAN算法。

    import numpy as np
    from matplotlib import pyplot as plt
    batch_size = 4
    
    

    2. 真实数据集,我们要通过GAN学习这个数据集,然后生成和他分布规则一样的数据集

    X = np.random.normal(size=(1000, 2))
    A = np.array([[1, 2], [-0.1, 0.5]])
    b = np.array([1, 2])
    X = np.dot(X, A) + b
    
    plt.scatter(X[:, 0], X[:, 1])
    plt.show()
    
    
    # 等会通过这个函数,不断从中取x值,取值数量为batch_size
    def iterate_minibatch(x, batch_size, shuffle=True):
        indices = np.arange(x.shape[0])
        if shuffle:
            np.random.shuffle(indices)
    
        for i in range(0, x.shape[0], batch_size):
            yield x[indices[i:i + batch_size], :]
    
    
    图片名称

    3.封装GAN对象

    包含生成器,判别器

    class GAN(object):
        def __init__(self):
            #初始函数,在这里对初始化模型
        def netG(self, z):
            #生成器模型
        def netD(self, x, reuse=False):
            #判别器模型
        
    

    4.生成器netG

    随意输入的z,通过z*w+b的矩阵运算(全连接运算),返回结果

        def netG(self, z):
            """1-layer fully connected network"""
    
            with tf.variable_scope("generator") as scope:
                W = tf.get_variable(name="g_W", shape=[2, 2],
                                    initializer=tf.contrib.layers.xavier_initializer(),
                                    trainable=True)
                b = tf.get_variable(name="g_b", shape=[2],
                                    initializer=tf.zeros_initializer(),
                                    trainable=True)
                return tf.matmul(z, W) + b
    

    5.判别器nefD

    判别器为三层全连接网络。隐层部分使用tanh激活函数。输出部分没有激活函数

        def netD(self, x, reuse=False):
            """3-layer fully connected network"""
    
            with tf.variable_scope("discriminator") as scope:
                if reuse:
                    scope.reuse_variables()
    
                W1 = tf.get_variable(name="d_W1", shape=[2, 5],
                                     initializer=tf.contrib.layers.xavier_initializer(),
                                     trainable=True)
                b1 = tf.get_variable(name="d_b1", shape=[5],
                                     initializer=tf.zeros_initializer(),
                                     trainable=True)
                W2 = tf.get_variable(name="d_W2", shape=[5, 3],
                                     initializer=tf.contrib.layers.xavier_initializer(),
                                     trainable=True)
                b2 = tf.get_variable(name="d_b2", shape=[3],
                                     initializer=tf.zeros_initializer(),
                                     trainable=True)
                W3 = tf.get_variable(name="d_W3", shape=[3, 1],
                                     initializer=tf.contrib.layers.xavier_initializer(),
                                     trainable=True)
                b3 = tf.get_variable(name="d_b3", shape=[1],
                                     initializer=tf.zeros_initializer(),
                                     trainable=True)
    
                layer1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
                layer2 = tf.nn.tanh(tf.matmul(layer1, W2) + b2)
                return tf.matmul(layer2, W3) + b3
    

    6.初始化__init__函数

    def __init__(self):
            # input, output
             #占位变量,等会用来保存随机产生的数,
            self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')   
            #占位变量,真实数据的
            self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x')  
    
            # define the network
            #生成器,对随机变量进行加工,产生伪造的数据
            self.fake_x = self.netG(self.z)  
    
             #判别器对真实数据进行判别,返回判别结果
             #reuse=false,表示不是共享变量,需要tensorflow开辟变量地址
            self.real_logits = self.netD(self.x, reuse=False)  
    
            #判别器对伪造数据进行判别,返回判别结果
             #reuse=true,表示是共享变量,复用netD中已有的变量
            self.fake_logits = self.netD(self.fake_x, reuse=True)
    
    
            # define losses
            #判定器的损失值,将真实数据的判定为真实数据,将伪造数据的判断为伪造数据的得分情况
            self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
                                                                                 labels=tf.ones_like(self.real_logits))) + 
                          tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                                 labels=tf.zeros_like(self.real_logits)))
            #生成器的生成分数。伪造的数据,别判断器判定为真实数据的得分情况
            self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                                 labels=tf.ones_like(self.real_logits)))
    
            # collect variables
            t_vars = tf.trainable_variables()
            #存放判别器中用到的变量
            self.d_vars = [var for var in t_vars if 'd_' in var.name]
            #存放生成器中用到的变量
            self.g_vars = [var for var in t_vars if 'g_' in var.name]
    

    7.开始训练

    gan = GAN()
    
    #使用随机梯度下降
    d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
    g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars)
    
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        #将数据循环10次
        for epoch in range(10):
            avg_loss = 0.
            count = 0
            #从真实数据集当中,随机抓取batch_size数量个值
            for x_batch in iterate_minibatch(X, batch_size=batch_size):
                # generate noise z
                #随机变量,数量为batch_size
                z_batch = np.random.normal(size=(4, 2))
    
                # update D network
                 #将拿到的真实数据值和随机生成的数值,喂养给sess,并bp优化一次
                loss_D, _ = sess.run([gan.loss_D, d_optim],
                                     feed_dict={
                                         gan.z: z_batch,
                                         gan.x: x_batch,
                                     })
    
                # update G network
                loss_G, _ = sess.run([gan.loss_G, g_optim],
                                     feed_dict={
                                         gan.z: z_batch,
                                         gan.x: np.zeros(z_batch.shape),  # dummy input
                                     })
    
                avg_loss += loss_D
                count += 1
    
            avg_loss /= count
            #每一个epoch都展示一次生成效果
            z = np.random.normal(size=(100, 2))
            # 随机生成100个数值,0到1000---用来从真实值里面取数据
            excerpt = np.random.randint(1000, size=1000)
            fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
                                                        feed_dict={gan.z: z, gan.x: X[excerpt, :]})
            accuracy = 0.5 * (np.sum(real_logits > 0.5) / 100. + np.sum(fake_logits < 0.5) / 100.)
            print('
    discriminator loss at epoch %d: %f' % (epoch, avg_loss))
            print('
    discriminator accuracy at epoch %d: %f' % (epoch, accuracy))
            plt.scatter(X[:, 0], X[:, 1])
            plt.scatter(fake_x[:, 0], fake_x[:, 1])
            plt.show()
    
    
    

    效果

    完整代码下载

    欢迎转载,转载请注明出处。欢迎沟通交流: panfengqqs@qq.com)

  • 相关阅读:
    confluent-kafka python Producer Consumer实现
    kafka producer.poll producer.flush consumer.poll的区别
    kafka Java创建生产者报错:Invalid partition given with record: 1 is not in the range [0...1)
    Kafka通讯的Java实例
    虚机克隆搭建kafka服务器集群
    kafka报错解决:Broker may not be avaliable
    Kafka+Zookeeper+confluent-kafka搭建
    Kafka学习笔记
    【SpringCloud】 第十篇: 高可用的服务注册中心
    【SpringCloud】 第九篇: 服务链路追踪(Spring Cloud Sleuth)
  • 原文地址:https://www.cnblogs.com/panfengde/p/10020224.html
Copyright © 2011-2022 走看看