zoukankan      html  css  js  c++  java
  • HMM的概率计算问题和预测问题的java实现

    HMM(hidden markov model)可以用于模式识别,李开复老师就是采用了HMM完成了语音识别。

    一下的例子来自于《统计学习方法》

    一个HMM由初始概率分布,状态转移概率分布,观测概率分布确定。并且基于两个假设:

    1 假设任意时刻t的状态只依赖于前一个时刻的状态,与其他时刻的状态和观测序列无关

    2 假设任意时刻的观测只依赖与该市可的马尔科夫的状态,与其他观测,状态无关。

    基于此,HMM有三个基本问题:

    1 概率计算问题,给定模型和观测序列,计算在模型下的观测序列出现的概率

    2 预测问题,已知模型和观测序列,求最有可能的状态序列

    3 学习问题,给出若干个观测序列,估计模型的参数,使得该模型下观测序列概率最大

    由于第三个问题涉及到EM算法,而今天还没有看,所以这里只解决了两个,明天写第三个。

    给出上面两个问题的实际例子

    有三个盒子,每个盒子有红球白球

    盒子   红球  白球

    1   5  5

    2  4  6

    3  7  3

    第一次从这三个盒子中随机取一个盒子的概率为0.2,0.4,0.4

    并且如果上一次抽取的是盒子1那么下一次抽取盒子1的概率为0.5,抽取盒子2的概率为0.2,盒子3的概率为0.3,我们通过一个状态转移矩阵来描述

    0.5  0.2  0.3

    0.3  0.5  0.2

    0.2  0.3  0.5    Aij表示从状态i转移到状态j的概率

    通过以上描述,我们能得到该HMM的模型参数

    状态转移矩阵:

    0.5  0.2  0.3

    0.3  0.5  0.2

    0.2  0.3  0.5

    观测概率分布:

    0.5  0.5

    0.4  0.6

    0.7  0.3 Bij表示第i个状态下观测值为j的概率,这里就是抽到红球和白球的概率

    初始概率:

    0.2,0.4,0.4表示一开始到各个状态的概率

    对于问题1:

    现在我们抽取三次,结果为:红白红,求其出现的概率。

    解决方法:

    采用前向算法

    就是我们从时刻1开始,先计算所有状态下观测为红的概率,接下来再求t2时刻会转移到某个状态的概率和,以此类推

    具体的可以看《统计学习方法》,http://www.cnblogs.com/tornadomeet/archive/2012/03/24/2415583.html这个说的也比较详细

    对于问题2:

    抽三次后,结果为红白红,求被抽到最有可能的盒子的序列

    解决方法:

    这里采用了维特比算法,其实就是很常见的动态规划的算法,和求最短路径一样。如果说t+1时刻的状态序列概率最大,那么t时刻的状态序列也应该是最大的。

    具体可以看《统计学习方法》

      1 import java.io.BufferedReader;
      2 import java.io.FileInputStream;
      3 import java.io.IOException;
      4 import java.io.InputStreamReader;
      5 import java.util.ArrayList;
      6 import java.util.HashMap;
      7 import java.util.Map;
      8 
      9 class Alaph{//Alaph和delta两个一样。。。一开始的时候delta思路错了,后来就不改了
     10     double pro;//用于存放概率
     11     int state;//存放状态值
     12     public String toString(){
     13         return "pro:"+pro+" state:"+state;
     14     }
     15 }
     16 
     17 class Delta{
     18     public double pro;
     19     public int pos;
     20     public String toString(){
     21         return "pro is "+pro+" pos is "+pos;
     22     }
     23 }
     24 
     25 class Utils{
     26     public static ArrayList<ArrayList<Double>> loadMatrix(String filename) throws IOException{//读取数据
     27         ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
     28         FileInputStream fis=new FileInputStream(filename);
     29         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
     30         BufferedReader br=new BufferedReader(isr);
     31         String line="";
     32         
     33         while((line=br.readLine())!=null){
     34             ArrayList<Double> data=new ArrayList<Double>();
     35             String[] s=line.split(" ");
     36             
     37             for(int i=0;i<s.length;i++){
     38                 data.add(Double.parseDouble(s[i]));
     39             }
     40             dataSet.add(data);
     41         }
     42         return  dataSet;
     43     }
     44     
     45     public static ArrayList<Double> loadState(String filename)throws IOException{//读取数据,这个和上面那个很像,
     46         FileInputStream fis=new FileInputStream(filename);
     47         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
     48         BufferedReader br=new BufferedReader(isr);
     49         String line="";
     50         ArrayList<Double> data=new ArrayList<Double>();
     51         while((line=br.readLine())!=null){
     52             
     53             String[] s=line.split(" ");
     54             
     55             for(int i=0;i<s.length;i++){
     56                 data.add(Double.parseDouble(s[i]));
     57             }
     58             
     59         }
     60         return  data;
     61     }
     62     
     63     
     64     public static ArrayList<Double> getColData(ArrayList<ArrayList<Double>> A,int index){//根据index值,获取相应的列的数据,后来好像没什么用到。。。囧
     65         ArrayList<Double> col=new ArrayList<Double>();
     66         for(int i=0;i<A.size();i++){
     67             col.add(A.get(i).get(index));
     68         }
     69         return col;
     70     }
     71     
     72     
     73     public static void showData(ArrayList<ArrayList<Double>> data){//debug的时候用的,打印
     74         for(ArrayList<Double> a:data){
     75             System.out.println(a);
     76         }
     77     }
     78     
     79     public static void showAlaph(ArrayList<Alaph> list){
     80         for(Alaph a:list){
     81             System.out.println(a);
     82         }
     83     }
     84     
     85     public static ArrayList<Alaph> copy(ArrayList<Alaph> list){//复制
     86         ArrayList<Alaph> temp=new ArrayList<Alaph>();
     87         for(Alaph a:list){
     88             Alaph b=new Alaph();
     89             b.pro=a.pro;
     90             b.state=a.state;
     91             temp.add(b);
     92         }
     93         return temp;
     94     }
     95     
     96     public static Delta copyDelta(Delta src){//和上面一样,没有什么用
     97         Delta d=new Delta();
     98         d.pro=src.pro;
     99         d.pos=src.pos;
    100         return d;
    101     }
    102     
    103     public static ArrayList<Delta> copyDeltaList(Delta[] list){//复制
    104         ArrayList<Delta> deltaList=new ArrayList<Delta>();
    105         for(Delta delta:list){
    106             Delta temp=copyDelta(delta);
    107             deltaList.add(temp);
    108         }
    109         return deltaList;
    110     }
    111     
    112     public static void showDeltaList(ArrayList<Delta> list){//debug
    113         for(Delta d:list){
    114             System.out.println(d);
    115         }
    116     }
    117     
    118     public static int getMaxIndex(ArrayList<Delta> list){//求list中值最大的下标
    119         double max=-1.0;
    120         int index=-1;
    121         for(int i=0;i<list.size();i++){
    122             if(list.get(i).pro>max){
    123                 max=list.get(i).pro;
    124                 index=i;
    125             }
    126         }
    127         return index;
    128     }
    129     
    130 }
    131 
    132 
    133 
    134 public class HMM {
    135     public static ArrayList<Alaph> getInitAlaph(ArrayList<Double> initState,ArrayList<ArrayList<Double>> B,int index){//第一步的时候,用于求各个状态下的初始情况
    136         ArrayList<Double> col=Utils.getColData(B,index);
    137         ArrayList<Alaph> alaphSet=new ArrayList<Alaph>();
    138         for(int i=0;i<col.size();i++){
    139             Alaph a=new Alaph();
    140             a.pro=col.get(i)*initState.get(i);//初始情况为初始状态*对应的观测概率矩阵的值
    141             a.state=i;
    142             alaphSet.add(a);
    143         }
    144         return alaphSet;
    145     }
    146     public static ArrayList<Delta> getInitDelta(ArrayList<Double> initState,ArrayList<ArrayList<Double>> B,int index){//和上面一样
    147         ArrayList<Double> col=Utils.getColData(B,index);
    148         ArrayList<Delta> alaphSet=new ArrayList<Delta>();
    149         for(int i=0;i<col.size();i++){
    150             Delta d=new Delta();
    151             d.pro=col.get(i)*initState.get(i);
    152             d.pos=i;
    153             alaphSet.add(d);
    154         }
    155         return alaphSet;
    156     }
    157     
    158     //用于求给定模型和观测序列下求,该模型下的观测序列出现的概率
    159     public static double calProb(ArrayList<ArrayList<Double>> A,ArrayList<ArrayList<Double>> B,ArrayList<Double> initState,String[] observe,Map<String,Integer> map){
    160         int index=map.get(observe[0]);
    161         ArrayList<Alaph> alaphList=getInitAlaph(initState,B,index);//先求第一步的状态概率
    162         for(int i=1;i<observe.length;i++){//对各个观测值进行求解
    163             String s=observe[i];
    164             int tag=map.get(s);
    165             ArrayList<Alaph> temp=Utils.copy(alaphList);
    166             for(Alaph alaph:alaphList){
    167                 int destState=alaph.state;
    168                 double pro=0;
    169                 for(Alaph a:temp){
    170                     int srcState=a.state;
    171                     pro+=a.pro*A.get(srcState).get(destState);
    172                 }
    173                 pro=pro*B.get(destState).get(tag);
    174                 alaph.pro=pro;
    175             }
    176         }
    177         double result=0;
    178         for(Alaph alaph:alaphList){
    179             result+=alaph.pro;
    180         }
    181         return result;
    182     }
    183     
    184     //用于求给定模型和观测序列下,求其最大可能性的状态序列
    185     public static void  decoding(ArrayList<ArrayList<Double>> A,ArrayList<ArrayList<Double>> B,ArrayList<Double> initState,String[] observe,Map<String,Integer> map){
    186         int index=map.get(observe[0]);
    187         
    188         ArrayList<Delta> deltaList=getInitDelta(initState,B,index);
    189         int length=B.size();
    190         Delta maxDeltaList[]=new Delta[B.size()];//用于存放各个状态下的最大概率对应的delta值
    191         ArrayList<ArrayList<Integer>> posList=new ArrayList<ArrayList<Integer>>();//用于存放各个状态下的最佳状态值
    192         
    193         for(int i=0;i<B.size();i++){
    194             ArrayList<Integer> a=new ArrayList<Integer>();
    195             a.add(i);
    196             posList.add(a);
    197         }
    198         
    199         for(int j=1;j<3;j++){
    200             ArrayList<Delta> maxList=new ArrayList<Delta>();
    201             String s=observe[j];
    202             int tag=map.get(s);
    203             for(int i=0;i<B.size();i++){
    204                 Delta max=new Delta();
    205                 double maxPro=-1.0;
    206                 int maxPos=-1;
    207                 int maxIndex=-1;
    208                 for(int k=0;k<deltaList.size();k++){
    209                     Delta delta=deltaList.get(k);
    210                     double pro=delta.pro*A.get(delta.pos).get(i)*B.get(i).get(tag);
    211                     if(pro>maxPro){
    212                         maxPro=pro;
    213                         maxPos=i;
    214                         maxIndex=k;
    215                     }
    216                 }
    217                 max.pro=maxPro;
    218                 max.pos=maxPos;
    219                 maxDeltaList[i]=max;
    220                 posList.get(i).add(maxIndex);
    221             }
    222             
    223             deltaList=Utils.copyDeltaList(maxDeltaList);
    224             System.out.println("  ");
    225         }
    226         
    227         System.out.println(posList.get(Utils.getMaxIndex(deltaList)));
    228         
    229     }
    230     
    231     /**
    232      * @param args
    233      * @throws IOException 
    234      */
    235     public static void main(String[] args) throws IOException {
    236         String dataA="C:/Users/Administrator/Desktop/upload/HMM/A.txt";
    237         String dataB="C:/Users/Administrator/Desktop/upload/HMM/B.txt";
    238         String state="C:/Users/Administrator/Desktop/upload/HMM/init.txt";
    239         ArrayList<ArrayList<Double>> A=Utils.loadMatrix(dataA);
    240         ArrayList<ArrayList<Double>> B=Utils.loadMatrix(dataB);
    241         ArrayList<Double> initState=Utils.loadState(state);
    242         String[] s={"Red","White","Red"};
    243         Map<String,Integer> map=new HashMap();
    244         map.put("Red",0);
    245         map.put("White",1);
    246         double pro=calProb(A,B,initState,s,map);
    247 //        System.out.println("pro is "+pro);
    248         decoding(A,B,initState,s,map);
    249     }
    250 
    251 }
  • 相关阅读:
    Linux常用的命令
    Docker编写镜像 发布个人网站
    Linux安装docker笔记
    单例模式
    Cache一致性协议之MESI
    linux环境搭建单机kafka
    【Ray Tracing The Next Week 超详解】 光线追踪2-4 Perlin noise
    【Ray Tracing The Next Week 超详解】 光线追踪2-3
    【Ray Tracing The Next Week 超详解】 光线追踪2-2
    【Ray Tracing The Next Week 超详解】 光线追踪2-1
  • 原文地址:https://www.cnblogs.com/sunrye/p/4579241.html
Copyright © 2011-2022 走看看