zoukankan      html  css  js  c++  java
  • 论文笔记之:SeqGAN: Sequence generative adversarial nets with policy gradient

     SeqGAN: Sequence generative adversarial nets with policy gradient 

    AAAI-2017 

     

    Paperhttps://arxiv.org/abs/1609.05473 

    Offical Tensorflow Codehttps://github.com/LantaoYu/SeqGAN 

    PyTorch Codehttps://github.com/suragnair/seqGAN 

     

    Introduction : 

    产生序列模拟数据来模仿 real data 是无监督学习中非常重要的课题之一。最近, RNN/LSTM 框架在文本生成上取得了非常好的效果,最常见的训练方法是:给定上一个 token,推测当前 token 的最大化似然概率。但是最大似然方法容易受到 “exposure bias” 的干扰:the model generates a sequence iteratively and predicts next token conditioned on its previously predicted ones that may be never observed in the training data。这种 training 和 inference 之间的差异可以招致 accumulatively,随着 sequence 的累计,将会随着 sequence 的增长,变得 prominent。为了解决这个问题,Bengio 在 2015 年提出了 schedule sampling (SS) 的方法,但是又有人说这种方法在某些情况下也会失效。另一个可能的解决方案(the training/inference discrepancy problem)是:在整个产生的序列上构建损失函数,而不是每一个翻译(to build the loss function on the entire generated sequence instead of each trainsition)。但是,在许多其他的应用上,如:poem generation 和 chatbot,一个 task specific loss 无法直接准确的用来评价产生的序列。

     

    GAN 是最近比较热门的研究课题,已经广泛的应用于 CV 的许多课题上,但是,不幸的是,直接用 GAN 来产生 sequence 有两个问题:

    (1),GAN 被设计用来产生 real-valued, continuous data,但是在直接产生 离散的 tokens 的序列,是有问题的,如:text。The reason is that in GANs, the generator starts with random sampling first and then a determistic transform, govermented by the model parameters. As such, the gradient of the loss from D w.r.t. the outputs by G is used to guide the generative model G (paramters) to slightly change the generated value to make it more realistic. 但是,如果基于离散的 tokens 产生的数据,从 D 的 loss 得到的 “slight change” 却不是很有道理,因为可能根本不存在这样的 token 使得这一改变有意义(因为 字典空间是有效的)。

    (2),GAN 仅仅可以提供 score/loss 给整个的 sequence,而对于部分产生的序列,却无法判断目前已经有多好了。(GAN can only give the score/loss for an entire sequence when it has been generated; for a partially generated sequence, it is non-trivial to balance how good as it is now and the future score as the entire sequence. )

      

    本文提出一种思路来解决上述问题,将 序列产生问题 看做是 序列决策问题(consider the sequence generation procedure as a sequential decision making problem)。产生器 被认为是 RL 当中的 agent;状态是 目前已经产生的 tokens,动作是 下一步需要产生的 token。不像 Bahdanau et al. 2016 提出的方法那样需要 a task specific sequence score, such as BLEU in machine translation,为了给出奖励,我们用 discriminator 来评价 sequence,并且反馈评价来引导 generative model 的学习。为了解决 当输出是离散的,梯度无法回传给 generative model 的情况,我们将 generative model 看做是 stochastic parameterized policy。在我们的策略梯度,我们采用 MC 搜索来近似 the state-action value。我们直接用 policy gradient 来训练 policy,很自然的就避免了传统 GAN 中,离散数据的微分困难问题(the differentiation difficulty for discrete data in a conventional GAN)。

     

    Sequence Generative Adversarial Nets :   

      

    As illustrated in Figure 1, the discriminative model Dφ is trained by providing positive examples from the real sequence data and negative examples from the synthetic sequences generated from the generative model Gθ. At the same time, the generative model Gθ is updated by employing a policy gradient and MC search on the basis of the expected end reward received from the discriminative model Dφ. The reward is estimated by the likelihood that it would fool the discriminative model Dφ. The specific formulation is given in the next subsection. 

     

    SeqGAN via PolicyGradient

    Following (Sutton et al. 1999), when there is no intermediate reward, the objective of the generator model (policy) Gθ(yt|Y1:t−1) is to generate a sequence from the start state s0 to maximize its expected end reward: 

      

    其中,RT 是整个序列的奖励,奖励来自于 判别器 Dφ。QGθ Dφ(s,a) is the action-value function of a sequence, i.e. the expected accumulative reward starting from state s, taking action a, and then following policy Gθ. 目标函数的合理性应该是: 从给定的初始状态,产生器的目标是产生一个序列,使得 discriminator 认为是真的。

      

    下一个问题就是:如何如何预测 the action-value function。本文当中,我们采用 REINFORCE algorithm,consider the estimated probability of being real by the discriminator D as the reward。意思是说,如果 判别器 D 认为给定的 fake sequence 是真的,其概率记为 reward,此时:概率越高,reward 越大,这两者是成正比例关系的。正式的来说,我们有:

      

    然而,这个 discriminator 仅仅提供了一个 reward 给一个已经结束的 sequence。因为我们实际上关心的是长期的汇报,在每一个时间步骤,我们不但应该考虑到 previous tokens 的拟合程度,也考虑到 the resulted future outcome。就像是下棋的游戏,玩家有时会放弃即可的奖赏,而为了得到更加长远的奖励。所以,为了评价 the action-value for an intermediate state,我们采用 MC search with a roll-out policy to sample the unkown last T-1 tokens。我们表示一个 N-time 的 MC search 为:

      

    其中,Y^n_{1:t} ={y1, ... , yt} and Y^n_{t+1:T} is sampled based on the roll-out policy and the current state。在我们的实验当中,$Geta$ 也设置为 the generator。为了降低 variance,并且得到更加精确地  action value 的估计值,我们运行 the roll-out policy starting from current state 直到 序列的结束,N times,以得到一批输出样本。所以,我们有:

      

    其中,我们看到 当没有即可奖赏的时候,该函数被迭代的定义为:the next-state value starting from state s' = Y1:t and rolling out to the end。

      

    利用 判别器 D 作为奖赏函数的一个函数是:it can be dynamically updated to further improve the generative model interatively(为了进一步的提升产生式模型,它可以被动态的更新)。一旦我们有了一笔新的 更加 realistic 的产生的序列,我们应该重新训练 the discriminator model as follows:

      

    每次当一个新的判别式模型已经被训练完毕的时候,我们已经准备好来更新 generator。所提出的 基于策略的方法依赖于优化一个参数化的策略,来直接最大化 the long-term reward。目标函数 J 的梯度可以写为:

      

    上述形式是由于 the deterministic state transition and zero intermediate rewards。利用 likelihood ratio,我们构建一种 unbiased estimation of Eq.(6) : 

      

    其中,$Y_{1:t}$ 是观察到的 intermediate state sampled from $G heta$。因为期望 E[*] 可以通过采样的方法进行估计,我们然后更新产生器的参数:

      

    其中,$alpha$ 代表了对应的时刻 h-th step 的学习率。

      

    整体的算法流程如下图所示:

      

    本文首先用 最大似然估计的方法进行预训练 产生器 G,然后用 迭代的进行 G, D 的训练。

      

    然后就是对 G 和 D 的具体结构进行了解释:

    The Generative Model for Sequence:

    用 LSTM 来编码 sentences,然后将其映射到 下一个时刻 token 的概率分布。

    The Discriminative Model for Sequence

    此处的判别器,作者利用 CNN 的方法来进行判别。作者首先将 Word 转为 vector,然后一句话弄成了一个 matrix,然后用多个卷积核,进行特征提取。为了提升精度,作者也加了 highway architecture based on the pooled feature maps. 最后,添加了 fc layer 以及 sigmoid activation 来输出 给定的序列为真的概率(to output the probabiltiy that the input sequence is real)。优化的目标是:最小化 the groundtruth labels 和 the predicted probability 之间的 cross entropy loss。

     

     

  • 相关阅读:
    [Unity3D]脚本中Start()和Awake()的区别
    win7中资源管理器不能显示图片预览的解决方法
    Unity Built-in Shader详解三
    Unity Built-in Shader详解二
    php扩展安装
    ImportError: No module named mysql 报错python引用mysql报错
    运维制度
    ps aux|awk -F'[ ]+' '$3>90{print $2}'|xargs -n1 kill -9
    运维面试必知十五大原理知识点
    resize2fs
  • 原文地址:https://www.cnblogs.com/wangxiaocvpr/p/7900476.html
Copyright © 2011-2022 走看看