zoukankan      html  css  js  c++  java
  • Gumbel-Max trick, Perturb and MAP and more

    好久不更博了,今晚得空写一个关于Gumbel-Max trick的短篇。

    Gumbel-Max trick是一个把从multinomial distribution采样转化为一个离散优化问题的trick。更普遍地说,这个trick把一个采样问题转化为一个优化问题。对于某些模型,比如含有环的MRF,得到exact sample是一个困难问题。但对某些特殊的MRF,在上面做离散优化有一些有效的算法。这时候用gumbel-max trick就可能有用。另一方面,采样理论和优化理论各自都独立发展了很长时间,都有成熟的系统,这个trick把两个领域联系起来。对具体问题,如果能分析清楚何时用采样比较有效,何时用优化比较有效,就有希望通过将两个方面共同考虑得到更有效的解法。

    首先看一下最简单的multinomial distribution,这里更方便的表示方法是用一个softmax

    $$p(x=k)=frac{exp(a_k)}{sum_i exp(a_i)}$$

    从这个分布采样的方法一般是计算好各个k的概率,每个k就对应[0,1]区间上的一段。然后采一个[0,1]里均匀分布的样本,看它落在哪一个区间,对应就得到一个x的样本。

    Gumbel-Max trick的核心是Gumbel distribution,它是一个定义在((-infty, +infty))上的概率分布

    $$p(x) = e^{-(x-mu)-e^{-(x-mu)}}$$

    这里(mu)是一个均值参数,另外也可以有一个scale参数,这里省略了。可以证明这是一个已经正则化的分布,首先

    $$egin{align*}int_{-infty}^{+infty} p(x)dx &= int_{-infty}^{+infty} e^{-(x-mu)-e^{-(x-mu)}} dx \ &= int_{-infty}^{+infty} e^{-x-e^{-x}}dxend{align*}$$

    下一步做变量替换(u=e^{-x}),那么(du = -e^{-x}dx = -u dx),从而

    $$egin{align*}int_{-infty}^{+infty} p(x) dx &= int_{+infty}^0 ue^{-u}left(-frac{1}{u} ight)du \ &= int_0^{+infty}e^{-u}du \ &= -e^{-u} vert_0^{+infty}=1end{align*}$$

    也容易得到该分布的累积分布函数CDF为(p(Xle x) = int_{-infty}^x p(t)dt = e^{-e^{-(x-mu)}}).  Gumbel分布的样本可以通过CDF求逆的方法得到。

    Gumbel-Max trick说的是,对于上面提到的softmax distribution,如果对每一个(a_k)加上(mu=0)的Gumbel noise (n_k)变为(a_k'),然后再取argmax{(a_k')},得到的k遵循原来的softmax distribution。也就是说可以通过(a_k)加上Gumbel noise,再取argmax的方法来对softmax distribution进行采样。这个结论可以证明如下。

    令按如此过程得到的采样为x,那么x=k意味着(a_k+n_k ge a_{k'}+n_{k'}, forall k' eq k)。同时由于各(n_k)相互独立,于是有

    $$egin{align*}p(x=k) &= intcdotsint p(a_k + n_k ge a_{k'}+n_{k'}, forall k' eq k)dn_1cdots dn_K \ &= int p(n_k)left[prod_{k' eq k} int p(a_k+n_kge a_{k'}+n_{k'}|n_k)d n_{k'} ight] dn_k \ &= int p(n_k)left[prod_{k' eq k}int p(n_{k'}le n_k + a_k - a_{k'}|n_k)dn_{k'} ight]dn_k\ &= int p(n_k) prod_{k eq k'}e^{-e^{-(n_k+a_k-a_{k'})}} dn_k \ &= int e^{-n_k-e^{-n_k}} e^{-e^{-n_k}sum_{k eq k'} e^{-(a_k - a_{k'})}} dn_k \ &= int e^{-n_k-e^{-n_k}sum_{k'} e^{-(a_k-a_{k'})}}dn_kend{align*}$$

    $$A = log sum_{k'}e^{-(a_k-a_{k'})}=log e^{-a_k}sum_{k'}e^{a_{k'}}=logsum_{k'}e^{a_{k'}}-a_k$$

    则易见(e^{-A})即为softmax distribution中的(p(x=k))。将A替换到上面的变换式中,有

    $$p(x=k)= e^{-A}int e^{-(n_k-A)-e^{-(n_k-A)}}dn_k = e^{-A}$$

    因此通过Gumbel noise再取argmax得到的k的分布就是原来的softmax分布。

    Gumbel-Max trick的应用

    Gumbel-Max trick把一个采样问题转化为了一个优化问题。这个转化有没有用呢?

    从时间复杂度上看,原先的采样过程要产生一个随机数,同时与O(K)个区间进行比较,复杂度为O(K)。现在需要产生K个独立Gumbel随机数,然后再求K个数的max。复杂度仍为O(K),看样子并没有什么效果(不过如果考虑并行化的方法,那么理论上Gumbel-Max trick确实可以把时间降到O(1))。但有的时候,优化问题比采样问题更容易解。

    考虑一个binary pairwise MRF,Gibbs分布形式为(p(X)=frac{1}{Z}expleft(sum_i b_i x_i + sum_{ij}w_{ij}x_i x_j ight)),其中所有(x_iin{-1,1})。当所有二阶系数(w_{ij}ge 0)的时候,求argmax p(X)可以转化为一个min-cut/max-flow问题,从而有非常有效的多项式时间解法(graph-cut)。Computer Vision里面做图像分割的绝大多数方法目前在某个层面上都用到了这个有效地graph-cut解法。但另一方面,想要求归一化常数Z或者想从这个MRF中采样,则被证明是困难的,需要指数量级的时间。

    不过怎么把Gumbel-Max trick用到这里呢?首先,我们要把MRF转化为一个softmax distribution,这一步可以通过穷举所有可能的X,然后对每一个X计算未归一化的概率(exp(phi(X))),其中(phi(X)=sum_i b_i x_i + sum_{ij}w_{ij}x_i x_j)。再下一步对每一个X的(phi(X))加一个Gumbel noise。最后对所有这些(phi(X)+n)求一个argmax。再回过头看一看,这和之前的argmax p(X)并不一样。

    没错,这里将MRF转化为softmax distribution的过程引入了指数数量的项,优化过程也用不到之前提到的高效的graph-cut方法。不过,这个理论结论更重要的意义是给出了MRF的另一种解释:首先把(phi(X))加入一个扰动,然后再在扰动后求一个优化问题的解,这样得到的模型和原MRF的Gibbs分布是一样的。

    这个想法推广开来,如果不把我们的扰动限定在Gumbel noise上,而是任意的在(phi)上加一个扰动n(Perturb),n可以遵循任意合理易采样的概率分布,再在(phi(X))上求一个argmax(MAP),这个过程同样也可以定义一个概率模型。不用Gumbel noise的话这个模型就不再和原来的MRF Gibbs分布等价,但仍然定义了一个概率模型。这个过程定义的概率模型是近几年Perturb and MAP研究的基础。

    应用到最开始的MRF上,(phi(X)=sum_i b_i x_i + sum_{ij}w_{ij}x_i x_j),如果我们对每一个(b_i)和(w_{ij})加上扰动,得到的新(hat{phi})的argmax依然容易求得。这个概率模型就同时有了能表示pairwise interaction,MAP容易求,采样也容易得到的多种好处。

    一般而言,如果一个模型的MAP容易求但采样困难,都可以使用Perturb and MAP。

    今年的NIPS上的一篇Oral文章A* Sampling,把Gumbel-Max trick推广到连续空间,从概率理论里找来Gumbel process,然后提出了一种采样方法结合了Gumbel process的特点和A*搜索,给出了一种从连续分布中采样的方法。这个方法和adaptive rejection sampling很像,有些情况下更优。

    总结:Gumbel-Max trick有时候有点用,但总的来说用处不太大。基于Gumbel-Max trick的Perturb and MAP想法有更大的普适性,应用也更广泛。

  • 相关阅读:
    [国家集训队]拉拉队排练 Manancher_前缀和_快速幂
    高手过愚人节 Manancher模板题_双倍经验
    [模板]manacher算法
    [POI2011]MET-Meteors 整体二分_树状数组_卡常
    [国家集训队]矩阵乘法 整体二分
    三维偏序(陌上花开) CDQ分治
    博客园美化之旅第一天(CSS图层关系,背景相关设置,字体相关设置)
    力扣题目解答自我总结(反转类题目)
    python插件,pycharm基本用法,markdown文本编写,jupyter notebook的基本操作汇总
    关于小程序websocket全套解决方案,Nginx代理wss
  • 原文地址:https://www.cnblogs.com/alexdeblog/p/4118181.html
Copyright © 2011-2022 走看看