zoukankan      html  css  js  c++  java
  • 指数分布族的变分推断

    指数分布族的变分推理

    回顾一下上节课变分推断讲的内容。
    与MCMC,Gibbs Sampling等基于抽样来估计未知分布的方法不同,变分推断通过一个优化过程来近似隐变量的后验概率分布。给定一组隐变量 (Z={z_1,...,z_M}),其中(M)表示(Z)可以分解为(M)个独立成分。变分推断用一个变分分布(q(Z))估计后验分布(P(Z|X)),通过最大化如下的ELBO来完成近似过程:

    [mathcal{L}(q(Z))=mathbb{E}_{q(Z)}[ln p(X,Z)]-mathbb{E}_{q(Z)}[ln q(Z)] ]

    最优的近似分布为

    [q^*(Z)=argmax_{q(Z)}mathcal{L}(q(Z)) ]

    根据 meanfield假设,我们可以对(q(Z))进行分解:

    [q(Z)=prod_{i=1}^M q_i(Z_i) ]

    与上一堂课所用方法不同,这次我们假设(q_i(Z_i),(i=1,2,...,M))属于指数分布族 ,这些分布的超参数是超参数(lambda_i,(i=1,2,...,M)),这样做的好处是将泛函优化问题简化为了超参数优化问题。

    我们可以用一个概率图模型来表示上一节课的模型:

    由于(Z_1,Z_2,...,Z_M)都是(X)的co-parents,根据概率图模型理论,要估计(Z_i)的后验,我们不仅需要知道(X)的信息,还需要知道({Z_1,...,Z_{i-1},Z_{i+1},...,Z_M})的信息,即(Z_i)的后验同时取决于(X)和其他的(Z_j),可以将其表示为(P(Z_i|X,Z_1,...,Z_{i-1},Z_{i+1},...,Z_M))
    为了数学上表示的方便,我们将({Z_1,...,Z_{i-1},Z_{i+1},...,Z_M})简记为(Z_{-i})。我们假设(Z_i)的后验属于指数分布族:

    [p(Z_i|X,Z_{-i})=h(Z_i)exp(T(Z_i)^Teta(X,Z_{-i})-A_g(eta(X,Z_{-i})) ]

    其中(T(Z_i))(Z_i)的充分统计量,(eta(X,Z_{-i}))是其自然参数,它是关于(X,Z_{-i})的函数,特别地,对于广义线性模型来说这个函数是线性的。

    同样地,我们假设每个近似分布(q_i(Z_i))也是指数分布族:

    [q_i(Z_i|lambda_i)=h(Z_i)exp(T(Z_i)^Tlambda_i-A_ell(lambda_i)) ]

    其中(lambda_i)(Z_i)的超参数。为了方便,记(lambda_{-i}={lambda_1,...,lambda_{i-1},lambda_{i+1},...,lambda_{M}})

    通过引入超参我们可以将问题转化为一个参数优化问题:

    [lambda^*_1,...,lambda^*_M=argmax_{lambda_1,...,lambda_M} mathcal{L}(lambda_1,...,lambda_M) ]

    用贝叶斯公式给目标函数做个变形:

    [mathcal{L}(q(Z))=mathbb{E}_{q(Z)}[ln P(Z_i|X,Z_{-i}) +ln P(Z_{-i}|X) + ln P(X)]-mathbb{E}_{q(Z)}[ln q(Z)] ]

    其参数化形式为:

    [egin{aligned}mathcal{L}(lambda_1,...,lambda_M)&= mathbb{E}_{q(Z|lambda)}[ln P(Z_i|X,Z_{-i}) +ln P(Z_{-i}|X) + ln P(X)]-mathbb{E}_{q(Z|lambda)}[ln q_i(Z_i|lambda_i)]-mathbb{E}_{q(Z|lambda)}[sum_{j eq i} ln q_j(Z_j|lambda_j)]\&= mathbb{E}_{q(Z|lambda)}[ln P(Z_i|X,Z_{-i})]-mathbb{E}_{q(Z|lambda)}[ln q_i(Z_i|lambda_i)]+constend{aligned} ]

    其中(mathbb{E}_{q(Z|lambda)}[ln P(Z_{-i}|X)],mathbb{E}_{q(Z|lambda)}[ln P(X)])(mathbb{E}_{q(Z|lambda)}[sum_{j eq i} ln q_j(Z_j|lambda_j)])都是常数。一开始不太理解为什么(mathbb{E}_{q(Z|lambda)}[ln P(Z_{-i}|X)])是常数,后来经过如下的计算想通了。

    [egin{aligned}mathbb{E}_{q(Z|lambda)}[ln P(Z_{-i}|X)]&= int_{Z_i} igg[int_{Z_{j eq i}} ln P(Z_{-i}|X)prod_{j eq i}q_j(Z_j|lambda_j) dZ_jigg] q_i(Z_i|lambda_i) dZ_i\&=int_{Z_{j eq i}} ln P(Z_{-i}|X)prod_{j eq i}q_j(Z_j|lambda_j) dZ_j int_{Z_i} q_i(Z_i|lambda_i) dZ_i\&=int_{Z_{j eq i}} ln P(Z_{-i}|X)prod_{j eq i}q_j(Z_j|lambda_j) dZ_j =constant(mbox{因为$lambda_{-i}$已知})end{aligned} ]

    代入指数分布式子有

    [ egin{aligned}mathcal{L}(lambda_1,...,lambda_M)&= mathbb{E}_{q(Z|lambda)}[ln P(Z_i|X,Z_{-i})]-mathbb{E}_{q(Z|lambda)}[ln q_i(Z_i|lambda_i)]+const\&= mathbb{E}_{q(Z|lambda)}[ln h(Z_i)]+mathbb{E}_{q(Z|lambda)}[T(Z_i)^Teta(X,Z_{-i})]-underbrace{mathbb{E}_{q(Z|lambda)}[A_g(eta(X,Z_{-i}))]}_{const}\&quad-mathbb{E}_{q(Z|lambda)}[ln h(Z_i)]-mathbb{E}_{q(Z|lambda)}[T(Z_i)^Tlambda_i]+mathbb{E}_{q(Z|lambda)}[A_ell(lambda_i)]+const\&=mathbb{E}_{q(Z|lambda)}[T(Z_i)^T(eta(X,Z_{-i})-lambda_i)]+mathbb{E}_{q(Z|lambda)}[A_ell(lambda_i)]+const\ &=mathbb{E}_{q_i(Z_i|lambda_i)}[T(Z_i)]^Tcdot mathbb{E}_{q_{-i}(Z_{-i}|lambda_{-i})}[eta(X,Z_{-i})-lambda_i]+mathbb{E}_{q(Z|lambda)}[A_ell(lambda_i)]+const\&=A'_ell(lambda_i)^Tmathbb{E}_{q_{-i}(Z_{-i}|lambda_{-i})}[eta(X,Z_{-i})-lambda_i]+mathbb{E}_{q(Z|lambda)}[A_ell(lambda_i)]+const\&=A'_ell(lambda_i)^T(mathbb{E}_{q_{-i}(Z_{-i}|lambda_{-i})}[eta(X,Z_{-i})]-lambda_i)+A_ell(lambda_i)+constend{aligned}]

    接着对(lambda_i)求导:

    [frac{partial mathcal{L}}{partial lambda_i}=A_ell''(lambda_i)(mathbb{E}_{q_{-i}(Z_{-i}|lambda_{-i})}[eta(X,Z_{-i})]-lambda_i)-A'_ell(lambda_i)+A'_ell(lambda_i)=0 ]

    一般来说(A_ell''(lambda_i) eq 0),于是我们有

    [lambda^*_i=mathbb{E}_{q_{-i}(Z_{-i}|lambda_{-i})}[eta(X,Z_{-i})] ]

    有了这个更新式,只要我们遍历(lambda_i)固定其他参数,就可以迭代地对ELBO进行优化,获得后验概率分布的估计。

  • 相关阅读:
    OpenSSL生成证书、密钥
    js中对String去空格
    正则表达式
    webapi调用
    记一次完整的CI持续集成配置过程(.net core+Jenkins+Gitea)
    处理asp.net core连接mysql的一个异常Sequence contains more than one matching element
    asp.net core 3.1+mysql8.0+Hangfire遇到的异常解决记
    升级到asp.net core 3.1遇到的json异常
    了解ASP.NET Core端点路由
    asp.net core 2.2升到3.1遇到的问题小记
  • 原文地址:https://www.cnblogs.com/wacc/p/5750251.html
Copyright © 2011-2022 走看看