zoukankan      html  css  js  c++  java
  • 李宏毅机器学习课程笔记-14.3 Seq2Seq:Tips for Generation

    在训练一个可以产生句子的网络时,有哪些技巧呢?

    Bad Attention

    假如要做video的caption generation,某视频有4个frame,即有4个时刻的图片。

    (alpha^i_t)表示attention weight,其上标表示frame的索引、下标表示时刻的索引。在第1个时刻,产生attention (alpha^1_1,alpha^2_1,alpha^3_1,alpha^4_1),生成第1个word (w_1);在第2个时刻,产生attention (alpha^1_2,alpha^2_2,alpha^3_2,alpha^4_2),生成第2个word (w_2);以此类推……

    这样有时候会产生一些bad attention。比如,如果4个时刻的attention都集中在某一个frame上,就会产生一些奇怪的结果,比如每次生成的word都是相同的。

    good attention需要关注到输入中的每个frame,对每个frame的关注度不能太多也不能太少并且应该是同等级的。那如何实现这种好的attention呢?比如使用正则项(sum_i( au-sum_talpha_t^i))使得每个frame在4个时刻的attention weight之和都接近( au),这个( au)是通过学习得到的,详见《Show, Attend and Tell: Neural Image Caption Generation with Visual Attention》。

    Mismatch between Train and Test

    假如用RNN生成sentence,在训练时模型refer了整个sentence,如果在某一步预测失败可以通过损失函数优化每一次预测;而在测试时,模型的输入是上一步的输出,如果一步出错,可能就会步步出错。这个问题叫做Exposure Bias。那么我们如何解决train和test之间的mismatch呢?

    可以考虑修改训练方法,假设模型现在应该输出A,但现在模型输出了B,即使是错误的输出我们也应该让这个错误的输出作为下一次的输入,这样train和testing就是match的,但实际上这种训练方法是难以见效的。有一个可行的方法是Scheduled Sampling,按一定概率选择模型上一步的输出或者标注作为模型的输入,可以在刚开始时只使用标注作为输入然后慢慢开始使用模型上一步的输出作为输入。

    在一颗庞大的树中搜索一条最优路径时,我们无法穷举所有路径,贪心策略找到的路径也不一定是最优路径。Beam Search就是指在每一步保留最好的几条路径。

    有人说直接把模型输出的分布直接作为下次的输入,但其实这样的结果会比较差,因为这样无法区分出接近的分布。

    Object Level V.S. Component Level

    假如我们要生成一个sentence,那我们就应该关注整个sentence(Object Level)而不仅仅是每个word(Component Level)。

    如果是按照Component Level,那使用Cross Entropy计算损失的话,训练前期loss会下降得很快,但后期loss会下降得很慢(“The dog is is fast”和"The dog is running fast"的loss的差距很小)。

    那有没有一个损失函数可以基于Object Level衡量两个句子间的差异呢?目前是没有的,因为模型输出的分布是离散的,如果微小改变模型参数但保证模型输出的句子相同,那损失函数的输出就是一样的,即微小扰动并没有对loss产生影响。

    那怎么办呢?

    Reinforcement Learning

    可以利用强化学习进行generation,每次生成一个word并不计算reward,知道生成整个sentence后才利用所生成的sentence和标注计算reward,详见《SEQUENCE LEVEL TRAINING WITH RECURRENT NEURAL NETWORKS》。


    Github(github.com):@chouxianyu

    Github Pages(github.io):@臭咸鱼

    知乎(zhihu.com):@臭咸鱼

    博客园(cnblogs.com):@臭咸鱼

    B站(bilibili.com):@绝版臭咸鱼

    微信公众号:@臭咸鱼

    转载请注明出处,欢迎讨论和交流!


  • 相关阅读:
    基于.Net Core的Redis:实现查询附近的地理信息
    基于.Net Core的Redis:基本数据类型及其应用场景与命令行操作
    C# WebClient几种常用方法的用法
    const学习(续)
    C++ const学习
    Unicode
    android studio下使用HAXM android模拟器(x86)加速器
    使用efinance包获取股票数据
    Linux初识
    UWSGI
  • 原文地址:https://www.cnblogs.com/chouxianyu/p/14798190.html
Copyright © 2011-2022 走看看