zoukankan      html  css  js  c++  java
  • DCGAN实现

    DCGAN实现

    代码

    • dcgan.py
    
    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    
    
    import os
    import math
    import argparse
    import cv2
    import numpy as np
    import tensorflow as tf
    
    # DataManager supplies training batches and writes generated images to disk.
    class DataManager(object):
        """Load training images from a directory and write output images.

        On read, pixel values are mapped from [0, 255] to roughly [-1, 1] via
        (x - 127.5) / 128; ``imwrite`` applies the inverse mapping.
        """

        def __init__(self, data_dir, batch_size=64):
            """
            Args:
                data_dir: directory containing the training images, or None/empty
                    when no training data is needed (e.g. sampling only).
                batch_size: number of images per batch (default 64).
            """
            self.data_dir = data_dir
            # (height, width, channels) every image is resized to.
            self.im_shape = (48, 48, 3)
            self.im_list = self._get_im_names()
            self.batch_size = batch_size
            # Number of full batches per epoch; a trailing partial batch is dropped.
            self.chunk_size = len(self.im_list) // self.batch_size

        def _get_im_names(self):
            """Return the entries of data_dir in random order (empty array if no dir)."""
            if not self.data_dir:
                return np.asarray([])
            im_list = np.asarray(os.listdir(self.data_dir))
            np.random.shuffle(im_list)
            return im_list

        def imread(self, im_name):
            """Read one image, resize it to im_shape and scale it to float32 ~[-1, 1].

            Raises:
                IOError: if OpenCV cannot decode the file (e.g. a non-image entry).
            """
            path = os.path.join(self.data_dir, im_name)
            im = cv2.imread(path)
            if im is None:
                # cv2.imread reports failure by returning None instead of raising.
                raise IOError('Cannot read image: %s' % path)
            # cv2.resize takes dsize as (width, height), the reverse of im_shape's order.
            im = cv2.resize(im, (self.im_shape[1], self.im_shape[0]))
            im = (im.astype('float32') - 127.5) / 128.0
            return im

        def imwrite(self, name, im):
            """Undo the imread scaling and save ``im`` as ./images/<name>.jpg.

            Returns:
                The success flag returned by cv2.imwrite.
            """
            out_dir = './images'
            # cv2.imwrite does not create missing directories; it just fails.
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)
            im = im * 128.0 + 127.5
            # Inputs in [-1, 1] map to [-0.5, 255.5]; clip so the uint8 cast cannot wrap.
            im = np.clip(im, 0, 255).astype('uint8')
            return cv2.imwrite('./images/%s.jpg' % name, im)

        def next_batch(self):
            """Yield chunk_size batches, each an array of batch_size images from imread."""
            for start in range(0, self.chunk_size * self.batch_size, self.batch_size):
                name_list = self.im_list[start:start + self.batch_size]
                yield np.asarray([self.imread(im_name) for im_name in name_list])
    
    # DCGAN implemented without any higher-level framework (Keras, Slim): every
    # network operation is wrapped in its own small method below.
    class DCGAN(object):
        """Deep Convolutional GAN (TensorFlow 1.x graph mode) for 48x48x3 images.

        Data input/output goes through a DataManager. train() writes checkpoints
        to ./checkpoints and TensorBoard summaries to ./logs; gen() restores the
        latest checkpoint and writes sample images.
        """

        def __init__(self, data_dir):
            """Set up hyper-parameters and the data manager.

            Args:
                data_dir: directory of training images, or None when only
                    sampling with gen().
            """
            # All data input/output is handled by the data manager.
            self.data_manager = DataManager(data_dir)
            self.batch_size = self.data_manager.batch_size
            self.im_shape = self.data_manager.im_shape
            self.chunk_size = self.data_manager.chunk_size

            # Length of the noise vector fed to the generator.
            self.z_len = 100
            self.learning_rate = 0.0002
            self.epochs = 100
            # Adam beta1 (first-moment decay rate).
            self.beta1 = 0.5
            # Images produced per gen() call; must equal batch_size because the
            # generator's transposed convolutions hard-code that batch dimension.
            self.sample_size = 64

        # Fully connected layer.
        def fc(self, ims, output_size, scope='fc'):
            """Return ims @ weights + biases for a 2-D input of shape [batch, features]."""
            with tf.variable_scope(scope, reuse=False):
                weights = tf.get_variable('weights', [ims.shape[1], output_size], tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))
                biases = tf.get_variable('biases', [1, output_size], initializer=tf.constant_initializer(0.0))
                return tf.matmul(ims, weights) + biases

        # Batch normalization.
        def batch_norm(self, x, epsilon=1e-5, momentum=0.9, scope='batch_norm', is_training=True):
            """Apply tf.contrib.layers.batch_norm with a learned scale (gamma).

            updates_collections=None makes the moving mean/variance updates run
            in place with the forward op rather than via the UPDATE_OPS collection.
            """
            with tf.variable_scope(scope, reuse=False):
                return tf.contrib.layers.batch_norm(x, epsilon=epsilon, decay=momentum, updates_collections=None, scale=True, is_training=is_training)

        # Strided convolution layer.
        def conv2d(self, ims, output_dim, scope='conv2d'):
            """5x5 convolution, stride 2, 'SAME' padding, plus a bias.

            With 'SAME' padding TensorFlow's output spatial size is
            ceil(input_size / stride), so each call halves height and width.
            """
            with tf.variable_scope(scope, reuse=False):
                # Convolution filters are [height, width, in_channels, out_channels];
                # note the channel order is the reverse of conv2d_transpose's.
                ksize = [5, 5, ims.shape[-1], output_dim]
                strides = [1, 2, 2, 1]
                weights = tf.get_variable('weights', ksize, tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))
                biases = tf.get_variable('biases', [1, 1, 1, output_dim], tf.float32, initializer=tf.constant_initializer(0.0))
                conv = tf.nn.conv2d(ims, weights, strides=strides, padding='SAME') + biases
                return conv

        # Transposed convolution layer.
        def deconv2d(self, ims, output_shape, scope='deconv2d'):
            """5x5 transposed convolution, stride 2, to an explicit output_shape, plus a bias.

            output_shape is [batch, height, width, channels]; its batch entry must
            match the actual batch size of ims.
            """
            with tf.variable_scope(scope, reuse=False):
                # Transposed-convolution filters are [height, width, out_channels, in_channels].
                ksize = [5, 5, output_shape[-1], ims.shape[-1]]
                strides = [1, 2, 2, 1]

                weights = tf.get_variable('weights', ksize, tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.02))
                biases = tf.get_variable('biases', [1, 1, 1, output_shape[-1]], tf.float32, initializer=tf.constant_initializer(0.0))

                deconv = tf.nn.conv2d_transpose(ims, weights, output_shape=output_shape, strides=strides) + biases
                return deconv

        # Leaky ReLU.
        def lrelu(self, x, alpha=0.2):
            """Return max(x, alpha * x), i.e. leaky ReLU for 0 < alpha < 1."""
            return tf.maximum(x, x * alpha)

        # Discriminator: a plain strided-conv classifier (no pooling) with batch norm.
        def discriminator(self, ims, reuse=False):
            """Score an image batch as real or fake.

            Args:
                ims: image batch (real images or generator output).
                reuse: True to share the variables created by an earlier call.

            Returns:
                (sigmoid probability, logits), each of shape [batch, 1].
            """
            with tf.variable_scope('discriminator', reuse=reuse):
                net = self.conv2d(ims, 64, scope='d_conv_1')
                net = self.lrelu(net)

                net = self.conv2d(net, 64 * 2, scope='d_conv_2')
                net = self.batch_norm(net, scope='d_bn_2')
                net = self.lrelu(net)

                net = self.conv2d(net, 64 * 4, scope='d_conv_3')
                net = self.batch_norm(net, scope='d_bn_3')
                net = self.lrelu(net)

                net = self.conv2d(net, 64 * 8, scope='d_conv_4')
                net = self.batch_norm(net, scope='d_bn_4')
                net = self.lrelu(net)

                # Flatten the feature map and project to a single logit.
                net = self.fc(tf.reshape(net, [-1, net.shape[1] * net.shape[2] * net.shape[3]]), 1, scope='d_fc_5')
                return tf.nn.sigmoid(net), net

        # Generator: a decoder of transposed convolutions (no pooling) with batch
        # norm; the final output goes through tanh.
        def generator(self, noise_z, is_training=True):
            """Map noise [batch_size, z_len] to images [batch_size, 48, 48, 3] in [-1, 1].

            Args:
                noise_z: noise tensor; its batch size must equal self.batch_size.
                is_training: forwarded to batch norm (False uses moving statistics).
            """
            with tf.variable_scope('generator', reuse=False):
                # Training images are 48x48, so the layer sizes are derived backwards:
                # 3x3 -> 6x6 -> 12x12 -> 24x24 -> 48x48.
                net = self.fc(noise_z, 3 * 3 * 64 * 8)
                net = tf.reshape(net, [-1, 3, 3, 64 * 8])
                net = self.batch_norm(net, scope='g_bn_1', is_training=is_training)
                net = tf.nn.relu(net)

                net = self.deconv2d(net, [self.batch_size, 6, 6, 64 * 4], scope='g_conv_2')
                net = self.batch_norm(net, scope='g_bn_2', is_training=is_training)
                net = tf.nn.relu(net)

                net = self.deconv2d(net, [self.batch_size, 12, 12, 64 * 2], scope='g_conv_3')
                net = self.batch_norm(net, scope='g_bn_3', is_training=is_training)
                net = tf.nn.relu(net)

                net = self.deconv2d(net, [self.batch_size, 24, 24, 64], scope='g_conv_4')
                net = self.batch_norm(net, scope='g_bn_4', is_training=is_training)
                net = tf.nn.relu(net)

                net = self.deconv2d(net, [self.batch_size, self.im_shape[0], self.im_shape[1], 3], scope='g_conv_5')

                return tf.nn.tanh(net)

        def train(self):
            """Build the GAN training graph and run the training loop.

            Resumes from the latest checkpoint in ./checkpoints when one exists,
            writes loss summaries to ./logs and saves a checkpoint whenever the
            generator step is a multiple of 100.
            """
            real_ims = tf.placeholder(tf.float32, [self.batch_size, self.im_shape[0], self.im_shape[1], self.im_shape[2]], name='real_ims')
            noise_z = tf.placeholder(tf.float32, [None, self.z_len], name='noise_z')

            # Loss functions
            fake_ims = self.generator(noise_z)
            real_prob, real_logits = self.discriminator(real_ims)
            # The pass over the fakes shares the discriminator's variables.
            fake_prob, fake_logits = self.discriminator(fake_ims, reuse=True)

            # D is pushed towards real -> 1 and fake -> 0; G towards fake -> 1.
            real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits)))
            fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))
            g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits)))
            d_loss = real_loss + fake_loss

            real_loss_sum = tf.summary.scalar('real_loss', real_loss)
            fake_loss_sum = tf.summary.scalar('fake_loss', fake_loss)
            g_loss_sum = tf.summary.scalar('g_loss', g_loss)
            d_loss_sum = tf.summary.scalar('d_loss', d_loss)

            # Optimizer: each network only updates its own variables, selected by scope prefix.
            train_vars = tf.trainable_variables()
            d_vars = [var for var in train_vars if var.name.startswith('discriminator')]
            g_vars = [var for var in train_vars if var.name.startswith('generator')]

            d_global_step = tf.Variable(0, name='d_global_step', trainable=False)
            # NOTE(review): this reuses the name 'd_global_step' (TF uniquifies the
            # op name in the graph). Left unchanged on purpose: renaming it would
            # make checkpoints written by earlier runs fail to restore.
            g_global_step = tf.Variable(0, name='d_global_step', trainable=False)
            d_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1).minimize(d_loss, var_list=d_vars, global_step=d_global_step)
            g_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1).minimize(g_loss, var_list=g_vars, global_step=g_global_step)

            saver = tf.train.Saver()
            init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

            ckpt_dir = './checkpoints'
            # Saver.save() refuses to write into a missing directory, which would
            # crash training at the first checkpoint; create it up front.
            if not os.path.isdir(ckpt_dir):
                os.makedirs(ckpt_dir)

            with tf.Session() as sess:
                sess.run(init_op)
                d_merged = tf.summary.merge([d_loss_sum, real_loss_sum, fake_loss_sum])
                g_merged = tf.summary.merge([g_loss_sum])
                writer = tf.summary.FileWriter('./logs', sess.graph)
                try:
                    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                    if ckpt and ckpt.model_checkpoint_path:
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        print('Restore from model.ckpt')
                    else:
                        print('Checkpoint is not found!')
                    for epoch in range(self.epochs):
                        for batch in self.data_manager.next_batch():
                            noises = np.random.uniform(-1, 1, size=(self.batch_size, self.z_len)).astype(np.float32)
                            _, d_summary, d_step = sess.run([d_optimizer, d_merged, d_global_step], feed_dict={real_ims: batch, noise_z: noises})
                            # The generator is updated twice per discriminator update,
                            # presumably so D does not overpower G; g_global_step is
                            # therefore incremented twice per batch.
                            sess.run(g_optimizer, feed_dict={noise_z: noises})
                            _, g_summary, g_step = sess.run([g_optimizer, g_merged, g_global_step], feed_dict={noise_z: noises})

                            writer.add_summary(d_summary, d_step)
                            writer.add_summary(g_summary, g_step)

                            # Extra forward pass only to log the losses after this batch's updates.
                            # NOTE(review): with batch norm in training mode this pass likely also
                            # updates the moving statistics - confirm this is intended.
                            loss_d, loss_real, loss_fake, loss_g = sess.run([d_loss, real_loss, fake_loss, g_loss], feed_dict={real_ims: batch, noise_z: noises})

                            print('Epoch: %s, Dis Step: %s, d_loss: %s, real_loss: %s, fake_loss: %s, Gen Step: %s, g_loss: %s'
                                    % (epoch, d_step, loss_d, loss_real, loss_fake, g_step, loss_g))
                            if g_step % 100 == 0:
                                saver.save(sess, './checkpoints/model.ckpt', global_step=g_step)
                                print('G Step %s Save model' % g_step)
                finally:
                    # FileWriter buffers events; close it so the final summaries reach disk.
                    writer.close()

        def gen(self):
            """Restore the generator from ./checkpoints and write sample images.

            Each of the sample_size generated images is written through
            DataManager.imwrite. Prints a message and returns if no checkpoint
            is found.

            Raises:
                ValueError: if sample_size != batch_size, because the generator's
                    output batch dimension is fixed to batch_size.
            """
            if self.sample_size != self.batch_size:
                raise ValueError('sample_size (%s) must equal batch_size (%s)' % (self.sample_size, self.batch_size))
            noise_z = tf.placeholder(tf.float32, [None, self.z_len], name='noise_z')
            # is_training=False makes batch norm use its stored moving statistics.
            sample_ims = self.generator(noise_z, is_training=False)
            saver = tf.train.Saver()
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())
                sample_noise = np.random.uniform(-1, 1, size=(self.sample_size, self.z_len))
                ckpt = tf.train.get_checkpoint_state('./checkpoints')
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('Checkpoint is not found!')
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                samples = sess.run(sample_ims, feed_dict={noise_z: sample_noise})
                for idx, sample in enumerate(samples):
                    self.data_manager.imwrite(idx, sample)
    
    
    def data_load_test(data_dir=None):
        """Smoke-test DataManager by writing its first batch back out as images.

        Args:
            data_dir: directory of training images (DataManager requires one).

        Prints a message and returns if data_dir yields no full batch.
        """
        manager = DataManager(data_dir)
        # next() with a default avoids StopIteration when there is no full batch.
        im_list = next(manager.next_batch(), None)
        if im_list is None:
            print('No full batch of images found in %r' % (data_dir,))
            return
        for idx, im in enumerate(im_list):
            manager.imwrite(idx, im)
    
    
    def main(argv=None):
        """Command-line entry point.

        ``--train DIR`` trains on the images in DIR; ``--gen yes`` restores the
        latest checkpoint and writes sample images to ./images.

        Args:
            argv: argument list to parse, without the program name; None makes
                argparse fall back to sys.argv[1:].
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--train', help='path to dataset')
        parser.add_argument('--gen', help="pass 'yes' to write sample images to ./images")
        args = parser.parse_args(argv)
        if args.train:
            dcgan = DCGAN(args.train)
            dcgan.train()
        elif args.gen:
            if args.gen == 'yes':
                dcgan = DCGAN(None)
                dcgan.gen()
            else:
                print('should be --gen yes')
        else:
            # Neither mode selected: show usage instead of a bare placeholder.
            parser.print_help()
    
    if __name__ == "__main__":
        # Script entry point: run the command-line interface.
        main()
    
    
  • 相关阅读:
    jsp上传下载+SmartUpload插件上传
    《鸟哥的Linux私房菜-基础学习篇(第三版)》(五)
    Activity的启动模式
    重学C++ (十一) OOP面向对象编程(2)
    寒城攻略:Listo 教你用 Swift 写IOS UI 项目计算器
    freemarker写select组件报错总结(二)
    用Radeon RAMDisk在Windows 10中创建关机或重新启动不消失的内存虚拟盘
    JS判断是否为JSON对象及是否存在某字段
    json、js数组真心不是想得那么简单
    javascript正则表达式
  • 原文地址:https://www.cnblogs.com/megachen/p/10803237.html
Copyright © 2011-2022 走看看