zoukankan      html  css  js  c++  java
  • GAN在seq2seq中的应用 Application to Sequence Generation

    Improving Supervised Seq-to-seq Model
    有监督的 seq2seq ,比如机器翻译、聊天机器人、语音辨识之类的 。

    而 generator 其实就是典型的 seq2seq model ,可以把 GAN 应用到这个任务中。

    RL(human feedback)
    训练目标是,最大化 expected reward。很大的不同是,并没有事先给定的 label,而是人类来判断,生成的 x 好还是不好。
     
    简单介绍一下 policy gradient。更新 encoder 和 generator 的参数来最大化 human 函数的输出。最外层对所有可能的输入 h 求和(weighted sum,因为不同的 h 有不同的采样概率);对一个给定的 h,对所有的可能的 x 求和(因为同样的 seq 输入可能会产生不一样的 seq 输出);求和项为 R(h, x)*P_θ (x | h) ,表示给定一个 h 产生 x 的概率以及对应得到的 reward(整项合起来看,就是 reward 的期望)

    用 sampling 后求平均来近似求期望:

    但是 R_θ 近似后并没有体现 θ(隐藏到 sampling 过程中去了),怎么算梯度?先对 P_θ (x | h) 求梯度,然后分子分母同乘 P_θ (x | h) ,而 grad(P_θ (x | h)) / P_θ (x | h) 就等于 grad(log P_θ (x | h)),所以就在 R_θ 原本的近似项上乘一个 grad(log P_θ (x | h))

    如果是 positive 的 reward(R(hi, xi) > 0), 更新 θ 后  P_θ (xi | hi) 会增加;反之会减小(所以最好人类给的 reward 是有正有负的)

    整个 implement 的过程就如下图所示,注意每次更新 θ 后,都要重新 sampling

    RL 的方法和之前所说的 seq2seq model (based on maximum likelihood)的区别

    GAN(discriminator feedback)
    不再是人给 feedback,而是 discriminator 给 feedback。

    训练流程。训练 D 来分辨 <c, x> pair 到底是来自于 chatbot 还是人类的对话;训练 G 来使得固定的 D 给来自 chatbot 的 (c', x~) 高分。

    仔细想一下,训练 G 的过程中是存在问题的,因为决定 LSTM 在每一个 time step 的 token 的时候实际上做了 sampling (或者取argmax),所以最后的 discriminator 的输出的梯度传不到 generator(不可微)。

    怎么解决?

      1. Gumbel-softmax https://casmls.github.io/general/2017/02/01/GumbelSoftmax.html

      首先需要可以采样,使得离散的概率分布有意义而不是只能取 argmax。对于 n 维概率向量 π,其对应的离散随机变量 xπ 添加 Gumbel 噪声再采样。
      xπ  = argmax(log(πi) + Gi)
      其中,G是独立同分布的标准 Gumbel 分布的随机变量,cdf 为 F(x) = exp(-exp(-x))。为了要可微,用 softmax 代替 argmax(因为 argmax 不可微,所以光滑地逼近),G可以通过 Gumbel 分布求逆,从均匀分布中生成 Gi = -log(-log(Ui)),Ui ~ U(0, 1) 
      

      2. Continuous Input for Discriminator 

      避免 sampling 过程,直接把每一个 time step 的 word distribution 当作 discriminator 的输入。

       

      这样做有问题吗?明显有,real sentence 的 word distribution 就是每个词 one-hot 的,而 generated sentence 的 word distribution 本质上就不会是 1-of-N,这样 discriminator 很容易就能分辨了,而且判断准则没有在考虑语义了(直接看是不是 one-hot 就行了)。

      

      3. Reinforcement Learning

       

      把 discriminator 的 output 看作是 reward:

        • Update generator to increase discriminator = to get maximum reward    
        • Using the formulation of policy gradient, replace reward  R(c, x) with discriminator output D(c, x)
      
      和典型的 RL 不同的是,discriminator 参数是要 update 的,还是要输入给 discriminator 现在 chatbot 产生的对话和人类的对话,训练 discriminator 来分辨。
      
     
    Unsupervised Seq-to-seq Model
     
    Text Style Transfer
    用 cycle GAN 来实现,训练两个 GAN,实现两个 domain 的互相转。仍旧要面对 generator 的输出要 sampling 的情况,选择上述第二种解决方案,就是连续化。直接用 word embedding 的向量。

    也可以用映射到 common space 的方法,sampling 后离散化的问题,可以用一个新的技巧解决:把 decoder LSTM 的 hidden layer 当作 discriminator 的输入,就是连续的了。

     
     
    Unsupervised Abstractive Summarization
     
    Unsupervised Translation
  • 相关阅读:
    NS3系列—4———NS3中文教程5:Tweaking NS3
    NS3系列—3———NS3中文:4 概念描述
    NS3系列—2———NS3笔录
    NS3系列—1———NS3中文教程:3下载及编译软件
    How to speed my too-slow ssh login?
    Linux bridge
    使用 GDB 和 KVM 调试 Linux 内核与模块
    How To Set Up A Serial Port Between Two Virtual Machines In VirtualBox
    Linux内核调试环境搭建(基于ubuntu12.04)
    Android
  • 原文地址:https://www.cnblogs.com/chaojunwang-ml/p/11453946.html
Copyright © 2011-2022 走看看