  • PRML 7: The EM Algorithm

     1. K-means Clustering: clustering can be regarded as a special parameter-estimation problem with latent variables. K-means performs a hard assignment of data points to clusters, in contrast to the Gaussian Mixture Model introduced later (a minimal code sketch follows the steps below).

      (1) Initialization of $K$ mean vectors;

      (2) E Step (Expectation): assign each point to a cluster by

        $y_n=\mathop{\arg\min}_{C_k}\|\vec{x}_n-\vec{\mu}_k\|$;

      (3) M Step (Maximization): renew mean vectors by

        $\vec{\mu}_k^{new}=\frac{\sum_{n=1}^N I\{y_n=C_k\}\vec{x}_n}{\sum_{n=1}^N I\{y_n=C_k\}}$;

      (4) Repeat (2) and (3) until convergence.
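
      A minimal NumPy sketch of the procedure above (function and variable names are my own, not from the book; X is assumed to be an N x D data array):

```python
import numpy as np

def kmeans(X, K, n_iter=100, seed=0):
    """Hard-assignment EM: alternate the E step (assign) and the M step (re-estimate means)."""
    rng = np.random.default_rng(seed)
    # (1) Initialization: pick K distinct data points as the initial mean vectors
    mu = X[rng.choice(len(X), size=K, replace=False)]
    for _ in range(n_iter):
        # (2) E step: assign each point to its nearest mean vector
        dists = np.linalg.norm(X[:, None, :] - mu[None, :, :], axis=2)   # N x K distances
        y = dists.argmin(axis=1)
        # (3) M step: renew each mean as the average of the points assigned to it
        new_mu = np.array([X[y == k].mean(axis=0) if np.any(y == k) else mu[k]
                           for k in range(K)])
        # (4) Stop once the mean vectors no longer change
        if np.allclose(new_mu, mu):
            break
        mu = new_mu
    return y, mu
```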

     2. Gaussian Mixture Model: assume $p(\vec{x})=\sum_{k=1}^K\pi_k\,\mathcal{N}(\vec{x}\mid\vec{\mu}_k,\Sigma_k)$ and that $\vec{x}_1,\vec{x}_2,...,\vec{x}_N$ are observed (a sketch of the EM updates follows the steps below).

      (1) Initialization of all the parameters;

      (2) E Step (Expectation): calculate the responsibility of component $\pi_k\,\mathcal{N}(\vec{x}_n\mid\vec{\mu}_k,\Sigma_k)$ for $\vec{x}_n$ by

        $\gamma_{nk}=\frac{\pi_k\cdot\mathcal{N}(\vec{x}_n\mid\vec{\mu}_k,\Sigma_k)}{\sum_{i=1}^K\pi_i\cdot\mathcal{N}(\vec{x}_n\mid\vec{\mu}_i,\Sigma_i)}$;

      (3) M Step (Maximization): re-estimate the parameters by

        $\vec{\mu}_k^{new}=\frac{1}{N_k}\sum_{n=1}^N\gamma_{nk}\cdot\vec{x}_n$,

        $\Sigma_k^{new}=\frac{1}{N_k}\sum_{n=1}^N\gamma_{nk}\cdot(\vec{x}_n-\vec{\mu}_k^{new})(\vec{x}_n-\vec{\mu}_k^{new})^T$,

        $\pi_k^{new}=N_k/N$,  where $N_k=\sum_{n=1}^N\gamma_{nk}$;

      (4) Repeat (2) and (3) until convergence.
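
      A rough NumPy/SciPy sketch of these EM updates (my own naming and initialization choices, not from the book; X is an N x D array):

```python
import numpy as np
from scipy.stats import multivariate_normal

def gmm_em(X, K, n_iter=100, seed=0):
    """EM for a Gaussian mixture: soft responsibilities instead of K-means' hard assignments."""
    N, D = X.shape
    rng = np.random.default_rng(seed)
    # (1) Initialization: random means, identity covariances, uniform mixing weights
    mu = X[rng.choice(N, size=K, replace=False)]
    Sigma = np.stack([np.eye(D)] * K)
    pi = np.full(K, 1.0 / K)
    for _ in range(n_iter):
        # (2) E step: gamma[n, k] proportional to pi_k * N(x_n | mu_k, Sigma_k)
        gamma = np.stack([pi[k] * multivariate_normal.pdf(X, mu[k], Sigma[k])
                          for k in range(K)], axis=1)
        gamma /= gamma.sum(axis=1, keepdims=True)
        # (3) M step: re-estimate the parameters from the responsibilities
        Nk = gamma.sum(axis=0)                    # effective number of points per component
        mu = (gamma.T @ X) / Nk[:, None]
        for k in range(K):
            diff = X - mu[k]
            Sigma[k] = (gamma[:, k, None] * diff).T @ diff / Nk[k]
        pi = Nk / N
    return pi, mu, Sigma, gamma
```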

     3. Forward-Backward Algorithm: a Hidden Markov Model (HMM) is a 3-tuple $\lambda=(A,B,\vec{\pi})$, where $A\in\mathbb{R}^{N\times N}$ is the state transition matrix, $B\in\mathbb{R}^{N\times M}$ is the observation probability matrix, and $\vec{\pi}\in\mathbb{R}^{N\times 1}$ is the initial state probability vector. An HMM assumes that the state at any time depends only on the previous state, and that the observation at any time depends only on the current state. Evaluating $p(O\mid\lambda)=\sum_{I}p(O\mid I,\lambda)\cdot p(I\mid\lambda)$ by enumerating all state sequences $I$ is too expensive, so we use either the forward algorithm or the backward algorithm for HMM evaluation instead (code sketches follow steps (2) and (3) below).

      (1) Forward Algorithm: we calculate $\alpha_t(i)=p(o_1,o_2,...,o_t,i_t=q_i\mid\lambda)$ by

        $\alpha_1(i)=\pi_i b_i(o_1)$  and  $\alpha_t(i)=[\sum_{j=1}^N\alpha_{t-1}(j)A_{ji}]b_i(o_t)\ (t>1)$,

        then we get $p(O\mid\lambda)=\sum_{i=1}^N\alpha_T(i)$;

      (2) Backward Algorithm: we calculate $\beta_t(i)=p(o_{t+1},o_{t+2},...,o_T\mid i_t=q_i,\lambda)$ by

        $\beta_T(i)=1$  and  $\beta_t(i)=\sum_{j=1}^N A_{ij}b_j(o_{t+1})\beta_{t+1}(j)\ (t<T)$,

        then we get $p(O\mid\lambda)=\sum_{i=1}^N\pi_i b_i(o_1)\beta_1(i)$;
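
      A minimal NumPy sketch of both recursions (function names are my own; A, B, pi are the matrices defined above and obs is a sequence of observation indices):

```python
import numpy as np

def forward(A, B, pi, obs):
    """alpha[t, i] = p(o_1, ..., o_t, i_t = q_i | lambda)."""
    T, N = len(obs), len(pi)
    alpha = np.zeros((T, N))
    alpha[0] = pi * B[:, obs[0]]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ A) * B[:, obs[t]]
    return alpha                  # p(O | lambda) = alpha[-1].sum()

def backward(A, B, pi, obs):
    """beta[t, i] = p(o_{t+1}, ..., o_T | i_t = q_i, lambda)."""
    T, N = len(obs), len(pi)
    beta = np.ones((T, N))
    for t in range(T - 2, -1, -1):
        beta[t] = A @ (B[:, obs[t + 1]] * beta[t + 1])
    return beta                   # p(O | lambda) = (pi * B[:, obs[0]] * beta[0]).sum()
```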

      (3) Viterbi Decoding: we define $V_t(i)=\mathop{\max}_{i_1,i_2,...,i_{t-1}}p(o_1,o_2,...,o_t,i_1,i_2,...,i_{t-1},i_t=q_i\mid\lambda)$ and calculate it by

        $V_1(j)=\pi_j b_j(o_1)$  and  $\begin{cases}\phi_t(j)=\mathop{\arg\max}_i V_{t-1}(i)A_{ij} \\ V_t(j)=A_{\phi_t(j),j}V_{t-1}(\phi_t(j))\cdot b_j(o_t)\end{cases}$ for $t>1$;

        Then the most probable state sequence can be recovered by backtracking:

        $q_T^{*}=\mathop{\arg\max}_j V_T(j)$  and  $q_t^{*}=\phi_{t+1}(q_{t+1}^{*})$ for $t<T$.
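
      A corresponding sketch of Viterbi decoding under the same assumptions (returns the most probable state path as state indices):

```python
import numpy as np

def viterbi(A, B, pi, obs):
    """Most probable state path via the max-product recursion plus backtracking."""
    T, N = len(obs), len(pi)
    V = np.zeros((T, N))                 # V[t, j]: best joint probability ending in state j at time t
    phi = np.zeros((T, N), dtype=int)    # phi[t, j]: best predecessor of state j at time t
    V[0] = pi * B[:, obs[0]]
    for t in range(1, T):
        scores = V[t - 1][:, None] * A   # scores[i, j] = V_{t-1}(i) * A_{ij}
        phi[t] = scores.argmax(axis=0)
        V[t] = scores.max(axis=0) * B[:, obs[t]]
    # Backtracking: q*_T = argmax_j V_T(j), then q*_t = phi_{t+1}(q*_{t+1})
    path = np.zeros(T, dtype=int)
    path[-1] = V[-1].argmax()
    for t in range(T - 2, -1, -1):
        path[t] = phi[t + 1, path[t + 1]]
    return path
```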

      A more general concept is the Probabilistic Graphical Model (PGM), which specifies both a factorization of the joint distribution and a set of conditional independence relations. A PGM can be either (1) a directed acyclic graph, a.k.a. a Bayesian Network, or (2) an undirected graph, a.k.a. a Markov Network. HMMs and neural networks are special cases of Bayesian networks.

     

     4. Baum-Welch Algorithm: to train an HMM by EM, we treat the observation sequence $O$ as the observed variables and the state sequence $I$ as the latent variables (a sketch of one update follows the steps below).

      (1) Initialization of all the parameters;

      (2) E Step (Expectation): use the forward-backward algorithm to calculate

        $\gamma_t(i)=p(i_t=q_i\mid O,\lambda)=\frac{\alpha_t(i)\beta_t(i)}{\sum_{j=1}^N\alpha_t(j)\beta_t(j)}$  and

        $\xi_t(i,j)=p(i_t=q_i\wedge i_{t+1}=q_j\mid O,\lambda)=\frac{\alpha_t(i)A_{ij}b_j(o_{t+1})\beta_{t+1}(j)}{\sum_{i=1}^N\sum_{j=1}^N\alpha_t(i)A_{ij}b_j(o_{t+1})\beta_{t+1}(j)}$;

      (3) M Step (Maximization): re-estimate the parameters by

        $A_{ij}^{new}=[\sum_{t=1}^{T-1}\xi_t(i,j)]/[\sum_{t=1}^{T-1}\gamma_t(i)]$,

        $b_j(k)^{new}=[\sum_{t=1}^T I\{o_t=\nu_k\}\cdot\gamma_t(j)]/[\sum_{t=1}^T\gamma_t(j)]$,

        $\pi_i^{new}=\gamma_1(i)$;

      (4) Repeat (2) and (3) until convergence.
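
      A rough sketch of one Baum-Welch update, reusing the forward and backward functions sketched above (again my own naming; obs holds observation indices and the HMM has N states and M observation symbols):

```python
import numpy as np

def baum_welch_step(A, B, pi, obs):
    """One EM iteration for an HMM: E step via forward-backward, M step via normalized expected counts."""
    obs = np.asarray(obs)
    T, N, M = len(obs), A.shape[0], B.shape[1]
    alpha, beta = forward(A, B, pi, obs), backward(A, B, pi, obs)
    # E step: gamma[t, i] = p(i_t = q_i | O), xi[t, i, j] = p(i_t = q_i, i_{t+1} = q_j | O)
    gamma = alpha * beta
    gamma /= gamma.sum(axis=1, keepdims=True)
    xi = (alpha[:-1, :, None] * A[None, :, :]
          * (B[:, obs[1:]].T * beta[1:])[:, None, :])
    xi /= xi.sum(axis=(1, 2), keepdims=True)
    # M step: re-estimate A, B, pi from the expected counts
    A_new = xi.sum(axis=0) / gamma[:-1].sum(axis=0)[:, None]
    B_new = np.stack([gamma[obs == k].sum(axis=0) / gamma.sum(axis=0)
                      for k in range(M)], axis=1)
    pi_new = gamma[0]
    return A_new, B_new, pi_new
```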

     5. EM Algorithm in general: given observed data $X$ and its joint distribution with latent data $Z$, $p(X,Z\mid\vec{\theta})$, where $\vec{\theta}$ denotes the unknown parameters, we carry out the following steps to maximize the likelihood $p(X\mid\vec{\theta})$.

      (1) Initialization of parameters $\vec{\theta}^{(0)}$;

      (2) E Step (Expectation): given $\vec{\theta}^{(i)}$, we estimate $q(Z)=p(Z\mid X,\vec{\theta}^{(i)})$;

      (3) M Step (Maximization): re-estimate $\vec{\theta}^{(i+1)}=\mathop{\arg\max}_{\vec{\theta}}\sum_Z q(Z)\ln p(X,Z\mid\vec{\theta})$;

      (4) Repeat (2) and (3) until convergence.

      For a detailed proof of the correctness of this algorithm, please refer to JerryLead's blog.

      In brief, our objective is to maximize $\ln p(X\mid\vec{\theta})=Q(\vec{\theta},\vec{\theta}^{(i)})-H(\vec{\theta},\vec{\theta}^{(i)})$,  where

       $Q(\vec{\theta},\vec{\theta}^{(i)})=\sum_Z p(Z\mid X,\vec{\theta}^{(i)})\ln p(X,Z\mid\vec{\theta})$,  $H(\vec{\theta},\vec{\theta}^{(i)})=\sum_Z p(Z\mid X,\vec{\theta}^{(i)})\ln p(Z\mid X,\vec{\theta})$.

      Since $H(\vec{\theta}^{(i+1)},\vec{\theta}^{(i)})\leq H(\vec{\theta}^{(i)},\vec{\theta}^{(i)})$ (non-negativity of the KL divergence, via Jensen's inequality), to make $\ln p(X\mid\vec{\theta})$ larger it suffices to ensure $Q(\vec{\theta}^{(i+1)},\vec{\theta}^{(i)})\geq Q(\vec{\theta}^{(i)},\vec{\theta}^{(i)})$.
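
      For completeness, the decomposition itself follows directly from the product rule: since $p(X,Z\mid\vec{\theta})=p(Z\mid X,\vec{\theta})\,p(X\mid\vec{\theta})$ and $\sum_Z p(Z\mid X,\vec{\theta}^{(i)})=1$,

       $\ln p(X\mid\vec{\theta})=\sum_Z p(Z\mid X,\vec{\theta}^{(i)})[\ln p(X,Z\mid\vec{\theta})-\ln p(Z\mid X,\vec{\theta})]=Q(\vec{\theta},\vec{\theta}^{(i)})-H(\vec{\theta},\vec{\theta}^{(i)})$,

      and the inequality on $H$ is exactly the non-negativity of a KL divergence:

       $H(\vec{\theta}^{(i)},\vec{\theta}^{(i)})-H(\vec{\theta},\vec{\theta}^{(i)})=\sum_Z p(Z\mid X,\vec{\theta}^{(i)})\ln\frac{p(Z\mid X,\vec{\theta}^{(i)})}{p(Z\mid X,\vec{\theta})}=KL(\,p(Z\mid X,\vec{\theta}^{(i)})\,||\,p(Z\mid X,\vec{\theta})\,)\geq 0$ for any $\vec{\theta}$.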

    References:

      1. Bishop, Christopher M. Pattern Recognition and Machine Learning. Singapore: Springer, 2006.

      2. Li Hang (李航). 统计学习方法 (Statistical Learning Methods). Beijing: Tsinghua University Press, 2012.

  • Original article: https://www.cnblogs.com/DevinZ/p/4582055.html