zoukankan      html  css  js  c++  java
  • tfgan折腾笔记(一):核心功能简要概述

    tfgan是什么?

    tfgan是tensorflow团队开发出的一个专门用于训练各种GAN的轻量级库,它是基于tensorflow开发的,所以兼容于tensorflow。在tensorflow1.x版本中,tfgan存在于tensorflow.contrib中,作为一个小模块供使用者调用。在更新到tensorflow2.0版本后,tfgan成为一个独立的库。可使用:

    pip install tensorflow-gan
    

     进行下载安装,并在python中使用以下语句导入这个包:

    import tensorflow_gan as tfgan
    

     可以使用tfgan对目前流行的GAN模型进行训练。并且,tfgan维护团队也会不断更新tfgan,使得其可以对论文中最新提出的GAN模型进行训练。

    tfgan项目托管在github中,点击这里可以查看tfgan在github中托管的源代码及其官方教程与示例。

    tfgan核心功能

    tfgan的中函数的功能主要集中在基于tensorflow的LOSS函数、优化器、训练迭代的封装,以及对GAN模型的评估。其它的如数据集的输入、生成器和判别器模型的结构以及推断过程则需要通过调用tensorflow函数自己编写。即使这样,tfgan也极大的简化了GAN的训练与实现。接下来就针对tfgan中的几个核心功能对应的函数进行一个预览,以便对tfgan有一个初步印象。具体的用法将在后续文章中详细说明注意:以下的代码中的函数均为调用,而不是函数原型

    tfgan核心函数示例

    ·初始化模型

    以Original-GAN为例进行说明,其它的例如C-GAN, info-GAN, Cycle-GAN等的情况与此处略有不同,在后续文章中会有具体说明

    gan_model = tfgan.gan_model(
        generator_fn=generator,
        discriminator_fn=discriminator,
        real_data=images,
        generator_inputs=tf.random.normal(
            [batch_size, noise_dims]
        )
    )
    

     在tfgan中,调用gan_model函数以创建Original-GAN网络模型,其主要参数包含4个,以下详细说明:

    generator_fn:需要先自定义一个生成器函数,函数中定义判别器网络模型,并将函数名称作为参数传入。定义的生成器函数的接口应当符合如下格式:

    def generator(noise, weight_decay=2.5e-5, is_training=True):
        '''GAN Generator.
    
        Args:
            noise: A 2D Tensor of shape [batch size, noise dim].
            weight_decay: The value of the l2 weight decay.
            is_training: If `True`, batch norm uses batch statistics. If `False`, batch
                norm uses the exponential moving average collected from population
                statistics.
    
        Returns:
            A generated image.
        '''
    

     discriminator_fn:同样,需要首先自定义一个判别器函数,函数中定义判别器网络模型。并将函数名称作为参数传入。定义的判别器函数的接口应当符合如下格式:

    def discriminator(img, unused_conditioning, weight_decay=2.5e-5):
        '''GAN discriminator.
    
        Args:
            img: Real or generated MNIST digits. Should be in the range [-1, 1].
            unuseed_conditioning: The TFGAN API can help with conditional GANs, which
                would require extra `condition` information to both the generator and the
                discriminator. Since this example is not conditional, we do not use this
                argument.
            weight_decay: The L2 weight decay.
    
        Returns:
            Logits for the probability that the image is real.
        '''
    

     real_data:真实图像。一个batch的Tensor格式。

    generator_inputs:输入GAN的随机噪声,一般通过tf.random.normal()函数获得。

    ·指定损失函数

    使用gan_loss函数指定训练GAN时所需要的损失函数,若调用形式如下所示,使用默认的损失函数:

    gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
    

     gan_model:上一步初始化模型时的返回值。

    add_summaries:是否添加损失的总结。tfgan在训练时,会自动生成tensorboard的日志信息(日志的位置将在最后一步“gan_train”函数中指定,tensorboard是一个适配于tensorflow的训练过程可视化工具),若为True,将添加loss的信息到日志中。

    或者使用tfgan中内置的其它loss函数,下面的函数调用时就使用了带权重惩罚的W距离。或者可以自己自定义loss函数,此处不再详述。

    gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=tfgan.losses.modified_generator_loss,
        discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
        mutual_information_penalty_weight=1.0,
        add_summaries=True
    )
    

    ·指定优化器

    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=tf.compat.v1.train.AdamOptimizer(3e-3, 0.5),
        discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(3e-4, 0.5),
        summarize_gradients=True
    )
    

     优化器一般需要传递4个参数:

    gan_model:第一步调用tfgan.gan_model的返回值

    gan_loss:第二步调用tfgan.gan_loss的返回值

    generator_optimizer:指定生成器的优化器

    discriminator_optimizer:指定判别器的优化器

    summarize_gradients:添加梯度的总结

    ·开始训练

    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=20)
        ],
        logdir=train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60
    )
    

     参数解释:

    train_ops:上一步函数的返回值

    hooks:tf.train.SessionRunHook类型的回调函数,用列表形式封装。此处的函数将在每次训练迭代时调用

    logdir:tfgan自动将建立好的网络模型以及训练过程的参数变化存储下来,此参数即为存储的位置

    get_hooks_fn:G和D的训练方式,get_joint_train_hooks()意为进行一次G+D的参数更新,然后再单独进行一次D的参数更新。以此为一个迭代周期。

    save_checkpoint_secs:训练过程中参数存储周期,此处设置为60s存储一次网络参数。

    调用gan_train函数后,训练开始进行。

    使用tfgan进行GAN网络训练步骤:

    1.定义Generator与Discriminator网络模型;

    2.加载训练集数据为batch形式;

    3.调用gan_model函数以初始化网络模型;

    4.调用gan_loss函数以指定损失函数;

    5.调用gan_train_ops函数以指定优化器;

    6.调用gan_train函数开始训练;

    7.训练完毕后,tfgan自动将网络模型及参数以及训练过程的总结(summarise)存储在硬盘中。

    使用tfgan进行推断的步骤:

    1.从tfgan保存的日志中加载网络模型及参数;

    2.加载测试数据;

    3.将数据传入(feed)网络,得到结果。

  • 相关阅读:
    MapReduce in MongoDB
    MongoDB的一些基本操作
    谈谈NOSQL
    Java中的反射(1)
    Mybatisの常见面试题
    关于Lombok和自动生成get set方法
    订Pizza(Java)
    美化Div的边框
    爱,死亡和机器人(Love,Death&Robots)
    session与cookie的介绍和两者的区别之其相互的关系
  • 原文地址:https://www.cnblogs.com/WongWai95/p/TFGAN-ZHE-TENG-BI-JI-1.html
Copyright © 2011-2022 走看看