zoukankan      html  css  js  c++  java
  • 不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 utils.py,输入如下代码:

    import scipy.misc
    import numpy as np
    
    # 保存图片函数
    def save_images(images, size, path):
        
        """
        Save the samples images
        The best size number is
                int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1
        example:
            The batch_size is 64, then the size is recommended [8, 8]
            The batch_size is 32, then the size is recommended [6, 6]
        """
    
        # 图片归一化,主要用于生成器输出是 tanh 形式的归一化
        img = (images + 1.0) / 2.0
        h, w = img.shape[1], img.shape[2]
    
        # 产生一个大画布,用来保存生成的 batch_size 个图像
        merge_img = np.zeros((h * size[0], w * size[1], 3))
    
        # 循环使得画布特定地方值为某一幅图像的值
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            merge_img[j*h:j*h+h, i*w:i*w+w, :] = image
        
        # 保存画布
        return scipy.misc.imsave(path, merge_img)

    这个函数的作用是在训练的过程中保存采样生成的图片。

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 model.py,定义生成器,判别器和训练过程中的采样网络,在 model.py 输入如下代码:

    import tensorflow as tf
    from ops import *
    
    BATCH_SIZE = 64
    
    # 定义生成器
    def generator(z, y, train = True):
        # y 是一个 [BATCH_SIZE, 10] 维的向量,把 y 转成四维张量
        yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
        # 把 y 作为约束条件和 z 拼接起来
        z = tf.concat(1, [z, y], name = 'z_concat_y')
        # 经过一个全连接,BN 和激活层 ReLu
        h1 = tf.nn.relu(batch_norm_layer(fully_connected(z, 1024, 'g_fully_connected1'), 
                                         is_train = train, name = 'g_bn1'))
        # 把约束条件和上一层拼接起来
        h1 = tf.concat(1, [h1, y], name = 'active1_concat_y')
        
        h2 = tf.nn.relu(batch_norm_layer(fully_connected(h1, 128 * 49, 'g_fully_connected2'), 
                                         is_train = train, name = 'g_bn2'))
        h2 = tf.reshape(h2, [64, 7, 7, 128], name = 'h2_reshape')
        # 把约束条件和上一层拼接起来
        h2 = conv_cond_concat(h2, yb, name = 'active2_concat_y')
    
        h3 = tf.nn.relu(batch_norm_layer(deconv2d(h2, [64,14,14,128], 
                                                  name = 'g_deconv2d3'), 
                                                  is_train = train, name = 'g_bn3'))
        h3 = conv_cond_concat(h3, yb, name = 'active3_concat_y')
        
        # 经过一个 sigmoid 函数把值归一化为 0~1 之间,
        h4 = tf.nn.sigmoid(deconv2d(h3, [64, 28, 28, 1], 
                                    name = 'g_deconv2d4'), name = 'generate_image')
        
        return h4
    
    # 定义判别器    
    def discriminator(image, y, reuse = False):
        
        # 因为真实数据和生成数据都要经过判别器,所以需要指定 reuse 是否可用
        if reuse:
            tf.get_variable_scope().reuse_variables()
    
        # 同生成器一样,判别器也需要把约束条件串联进来
        yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
        x = conv_cond_concat(image, yb, name = 'image_concat_y')
        
        # 卷积,激活,串联条件。
        h1 = lrelu(conv2d(x, 11, name = 'd_conv2d1'), name = 'lrelu1')
        h1 = conv_cond_concat(h1, yb, name = 'h1_concat_yb')
        
        h2 = lrelu(batch_norm_layer(conv2d(h1, 74, name = 'd_conv2d2'), 
                                    name = 'd_bn2'), name = 'lrelu2')
        h2 = tf.reshape(h2, [BATCH_SIZE, -1], name = 'reshape_lrelu2_to_2d')
        h2 = tf.concat(1, [h2, y], name = 'lrelu2_concat_y')
    
        h3 = lrelu(batch_norm_layer(fully_connected(h2, 1024, name = 'd_fully_connected3'), 
                                    name = 'd_bn3'), name = 'lrelu3')
        h3 = tf.concat(1,[h3, y], name = 'lrelu3_concat_y')
        
        # 全连接层,输出以为 loss 值
        h4 = fully_connected(h3, 1, name = 'd_result_withouts_sigmoid')
        
        return tf.nn.sigmoid(h4, name = 'discriminator_result_with_sigmoid'), h4
        
    # 定义训练过程中的采样函数    
    def sampler(z, y, train = True):
        tf.get_variable_scope().reuse_variables()
        return generator(z, y, train = train)

    可以看到,生成器由 7 × 7  变为 14 × 14 再变为 28 × 28大小,每一层都加入了约束条件 y,完美的诠释了论文所给出的网络,之所以要加入 is_train 参数,是由于 Batch_norm 层中训练和测试的时候的过程是不同的,用这个参数区分训练和测试,生成器的最后一层,用了一个 sigmoid 函数把值归一化到 0~1 之间,如果是不加约束的网络,则用 tanh 函数,所以在 save_images 函数中要用到语句:img = (images + 1.0) / 2.0。

    sampler 函数的作用是在训练过程中对生成器生成的图片进行采样,所以这个函数必须指定 reuse 可用,关于 reuse 说明,请看:http://www.cnblogs.com/Charles-Wan/p/6200446.html。

    参考资料:

    1. https://github.com/carpedm20/DCGAN-tensorflow

  • 相关阅读:
    moment获取天的23时59分59秒可以用moment().endOf(String),以及获取天的0时0分0秒可以用moment().startOf('day')
    vue 去除输入框首位的空格
    管道
    事件广播
    iview在子组件中调用父组件的方法
    ZOJ 3430 Detect the Virus(AC自动机)
    HDU 3065 病毒侵袭持续中(AC自动机)
    HDU 2896 病毒侵袭(AC自动机)
    HDU 2222 Keywords Search(AC自动机)
    shell常用命令
  • 原文地址:https://www.cnblogs.com/Charles-Wan/p/6328379.html
Copyright © 2011-2022 走看看