zoukankan      html  css  js  c++  java
  • HMM-维特比算法理解与实现(python)

    HMM-前向后向算法理解与实现(python)
    HMM-维特比算法理解与实现(python)

    解码问题

    • 给定观测序列 (O=O_1O_2...O_T),模型 (lambda (A,B,pi)),找到最可能的状态序列 (I^∗={i^∗_1,i^∗_2,...i^∗_T})

    近似算法

    • 在每个时刻 (t) 选择最可能的状态,得到对应的状态序列

    根据HMM-前向后向算法计算时刻 (t) 处于状态 (i^*_t) 的概率:

    [i^∗_t=argmax[gamma_t(i)],t=1,2,...T\ gamma_t(i) = frac{alpha_{i}(t) eta_{i}(t)}{sum_{i=1}^{N} alpha_{i}(t) eta_{i}(t)} ]

    但是无法保证得到的解是全局最优解

    维特比算法

    维特比算法的基础可以概括为下面三点(来源于吴军:数学之美):

    1. 如果概率最大的路径经过篱笆网络的某点,则从起始点到该点的子路径也一定是从开始到该点路径中概率最大的。

    2. 假定第 t 时刻有 k 个状态,从开始到 t 时刻的 k 个状态有 k 条最短路径,而最终的最短路径必然经过其中的一条。

    3. 根据上述性质,在计算第 t+1 时刻的最短路径时,只需要考虑从开始到当前的k个状态值的最短路径和当前状态值到第 t+1 时刻的最短路径即可。如求t=3时的最短路径,等于求t=2时,从起点到当前时刻的所有状态结点的最短路径加上t=2t=3的各节点的最短路径。

    image-20200512214719644

    通俗理解维特比算法,对上面三点加深理解

    假如你从S和E之间找一条最短的路径,最简单的方法就是列出所有可能的路径 ((O(T^N))),选出最小的,显然时间复杂度太高。怎么办?(摘自[3])

    使用维特比算法

    image-20200512223610958

    S到A列的路径有三种可能:S-A1,S-A2,S-A3,如下图

    image-20200513202915071

    S-A1,S-A2,S-A3 中必定有一个属于全局最短路径。继续往右,到了B列

    对B1:

    image-20200513202742395

    会产生3条路径:

    S-A1-B1,S-A2-B1,S-A3-B1
    

    假设S-A3-B1是最短的一条,删掉其他两条。得到

    image-20200513203551041

    对B2:

    image-20200513203743119

    会产生3条路径:

    S-A1-B2,S-A2-B2,S-A3-B2
    

    假设S-A1-B2是最短的一条,删掉其他两条。得到

    image-20200513203847969

    对B3:

    image-20200513204015153

    会产生3条路径:

    S-A1-B3,S-A2-B3,S-A3-B3
    

    假设S-A2-B3是最短的一条,删掉其他两条。得到

    image-20200513204233084

    现在我们看看对B列的每个节点有哪些,回顾维特比算法第二点

    假定第 t 时刻有 k 个状态,从开始到 t 时刻的 k 个状态有 k 条最短路径,而最终的最短路径必然经过其中的一条

    B列有三个节点,所以会有三条最短路径,最终的最短路径一定会经过其中一条。如下图

    image-20200513204552391

    同理,对C列,会得到三条最短路径,如下图

    image-20200513205546888

    到目前为止,仍然无法确定哪条属于全局最短。最后,我们继续看E节点

    image-20200513205723395

    最终发现最短路径为S-A1-B2-C3-E

    数学描述

    在上述过程中,对每一列(每个时刻)会得到对应状态数的最短路径。在数学上如何表达?记录路径的最大概率值 $ delta_t(i)$ 和对应路径经过的节点 (psi_t(i))

    定义在时刻 (t) 状态为 (i) 的所有单条路径中概率最大值为

    [delta_{t}(i)=max _{i_{1}, i_{2}, ldots, i_{t-1}} Pleft(i_{t}=i, i_{t-1}, ldots, i_{1}, o_{t}, ldots, o_{1} | lambda ight), i=1,2, ldots, N ]

    递推公式

    [egin{aligned} delta_{t+1}(i) &=max _{i_{1}, i_{2}, ldots, i_{t}} Pleft(i_{t+1}=i, i_{t}, ldots, i_{1}, o_{t+1}, ldots, o_{1} | lambda ight) \ &=max _{1 leq j leq N}left[delta_{t}(j) a_{j i} ight] b_{i}left(o_{t+1} ight), i=1,2, ldots, N ; t=1,2, ldots, T-1 end{aligned} ]

    定义在时刻 (t) 状态为 (i) 的所有单条路径中,概率最大路径的第 (t - 1) 个节点为

    [psi_{t}(i)=arg max _{1 leq j leq N}left[delta_{t-1}(j) a_{j i} ight], i=1,2, ldots, N ]

    维特比算法步骤:

    ​ step1:初始化

    [egin{aligned}&delta_{1}(i)=pi_{i} b_{i}left(o_{1} ight), i=1,2, ldots, N\&psi_{1}(i)=0, i=1,2, ldots, N\end{aligned} ]

    ​ step2:递推,对 (t=2,3,...,T)

    [delta_{t}(i)=max _{1 leq j leq N}left[delta_{t-1}(j) a_{j i} ight] b_{i}left(o_{t} ight), i=1,2, ldots, N \psi_{t}(i)=arg max _{1 leq j leq N}left[delta_{t-1}(j) a_{j i} ight], i=1,2, ldots, N \ ]

    ​ step3:计算时刻 (T) 最大的 (delta_T(i)) ,即为最可能隐藏状态序列出现的概率。计算时刻(T)最大的 (psi_T(i)) ,即为时刻(T)最可能的隐藏状态。

    [P^{*}=max _{1 leq i leq N} delta_{T}(i) quad i_{T}^{*}=arg max _{1 leq i leq N} delta_{T}(i) ]

    ​ step4:最优路径回溯,对(t=T-1,...,1)

    [i_{t}^{*}=psi_{t+1}left(i_{t+1}^{*} ight)\I^*=(i_{1}^{*},i_{2}^{*},...,i_{T}^{*}) ]

    代码实现

    假设从三个 袋子 {1,2,3}中 取出 4 个球 O={red,white,red,white},模型参数(lambda = (A,B,pi)) 如下,计算状态序列,即取出的球来自哪个袋子

    #状态 1 2 3
    A = [[0.5,0.2,0.3],
    	 [0.3,0.5,0.2],
    	 [0.2,0.3,0.5]]
    
    pi = [0.2,0.4,0.4]
    
    # red white
    B = [[0.5,0.5],
    	 [0.4,0.6],
    	 [0.7,0.3]]
    
    def hmm_viterbi(A,B,pi,O):
        T = len(O)
        N = len(A[0])
        
        delta = [[0]*N for _ in range(T)]
        psi = [[0]*N for _ in range(T)]
        
        #step1: init
        for i in range(N):
            delta[0][i] = pi[i]*B[i][O[0]]
            psi[0][i] = 0
            
        #step2: iter
        for t in range(1,T):
            for i in range(N):
                temp,maxindex = 0,0
                for j in range(N):
                    res = delta[t-1][j]*A[j][i]
                    if res>temp:
                        temp = res
                        maxindex = j
    
                delta[t][i] = temp*B[i][O[t]]#delta
                psi[t][i] = maxindex
    
        #step3: end
        p = max(delta[-1])
        for i in range(N):
            if delta[-1][i] == p:
                i_T = i
    
        #step4:backtrack
        path = [0]*T
        i_t = i_T
        for t in reversed(range(T-1)):
            i_t = psi[t+1][i_t]
            path[t] = i_t
        path[-1] = i_T
        
        return delta,psi,path
    
    A = [[0.5,0.2,0.3],[0.3,0.5,0.2],[0.2,0.3,0.5]]
    B = [[0.5,0.5],[0.4,0.6],[0.7,0.3]]
    pi = [0.2,0.4,0.4]
    O = [0,1,0,1]
    hmm_viterbi(A,B,pi,O)
    

    结果

    image-20200513231008945

    references:

    [1]https://www.cnblogs.com/kaituorensheng/archive/2012/12/04/2802140.html

    [2] https://blog.csdn.net/hudashi/java/article/details/87875259

    [3] https://www.zhihu.com/question/20136144

  • 相关阅读:
    Spring boot unable to determine jdbc url from datasouce
    Unable to create initial connections of pool. spring boot mysql
    spring boot MySQL Public Key Retrieval is not allowed
    spring boot no identifier specified for entity
    Establishing SSL connection without server's identity verification is not recommended
    eclipse unable to start within 45 seconds
    Oracle 数据库,远程访问 ora-12541:TNS:无监听程序
    macOS 下安装tomcat
    在macOS 上添加 JAVA_HOME 环境变量
    Maven2: Missing artifact but jars are in place
  • 原文地址:https://www.cnblogs.com/gongyanzh/p/12878375.html
Copyright © 2011-2022 走看看