zoukankan      html  css  js  c++  java
  • 隐马尔科夫模型

    特征向量:跟踪框位置相对轨迹中心的比值,角度,速度。

    马尔科夫模型:

    State Sequence: $q_1\, q_2\, \cdots\, q_T$

    $T$ 个状态之间的转移可见,则这个状态序列的概率是 $\pi_{q_1} \times a_{q_1 q_2} \times \cdots \times a_{q_{T-1} q_T}$

    隐马尔科夫模型:

    状态不可见(隐藏),只能由观察值推测;因此每个时刻还要引入观察值概率 $b_{q_t}(o_t)$,即处于状态 $q_t$ 时观察到 $o_t$ 的概率。

    $\pi_{q_1} \times b_{q_1}(o_1) \times a_{q_1 q_2} \times b_{q_2}(o_2) \times \cdots \times a_{q_{T-1} q_T} \times b_{q_T}(o_T)$

    三个问题:

    1. 评价问题——前向后向算法

    给定模型和观察序列,计算所有可能状态路径的概率之和,即该观察序列出现的概率。

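    前向:计算时向前运用结合律(分配律),把对所有路径的求和改写为逐时刻的递推 $\alpha_t(j) = \big[\sum_i \alpha_{t-1}(i)\, a_{ij}\big]\, b_j(o_t)$。其意义在于各路径共享的前缀只计算一次,计算量由枚举全部 $N^T$ 条路径降为 $O(N^2 T)$。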

    2. 解码问题

    计算所有可能路径中概率最大的一条路径

    3. 学习问题

    给定一个观察序列 $o_1\, o_2\, \cdots\, o_T$,更新 $A$、$B$、$\Pi$,使得评价问题(Evaluation Problem)算得的概率尽量大。

    程序:

    解码问题:

    // Viterbi decoding: finds the most likely state path for an observation
    // sequence under the current model.
    //
    // seq   - observation sequence; each element is passed unchanged to the
    //         per-state model's GetProbability().
    // state - receives the decoded path (one state index per observation).
    //         The path is appended and the whole vector is then reversed, so
    //         callers are expected to pass an empty vector.
    //
    // Returns exp(bestLogProb / seq.size()), i.e. the best path's likelihood
    // normalised per observation. For an empty sequence (or a model without
    // states) nothing can be decoded: `state` is left untouched and 0.0 is
    // returned.
    double CHMM::Decode(vector<double*>& seq, vector<int>& state)
    {
        const int size = (int)seq.size();
        if (size <= 0 || m_stateNum <= 0)
        {
            // The original code indexed seq[0] / path[0] and divided by size
            // here, which is undefined for an empty input.
            return 0.0;
        }

        // Vectors release their storage automatically, including when one of
        // the model calls below throws.
        vector<double> lastLogP(m_stateNum, 0.0);
        vector<double> currLogP(m_stateNum, 0.0);
        // path[t][i] is the predecessor of state i at time t on the best path.
        // Entries default to 0 so the backtrack never reads an indeterminate
        // index, even if no candidate beats the -1e308 sentinel.
        vector<vector<int> > path(size, vector<int>(m_stateNum, 0));
        int i, j, t;

        // Init: initial-state probability times emission of the first observation.
        for (i = 0; i < m_stateNum; i++)
        {
            currLogP[i] = LogProb(m_stateInit[i]) +
                LogProb(m_stateModel[i]->GetProbability(seq[0]));
            path[0][i] = -1; // no predecessor at t == 0 (never read back)
        }

        // Recursion: for every observation, the best accumulated log-probability
        // of ending in each state.
        for (t = 1; t < size; t++)
        {
            lastLogP.swap(currLogP);

            for (i = 0; i < m_stateNum; i++)
            {
                double bestLogP = -1e308;
                int bestPrev = 0;

                // Search the best predecessor state.
                for (j = 0; j < m_stateNum; j++)
                {
                    const double l = lastLogP[j] + LogProb(m_stateTran[j][i]);
                    if (l > bestLogP)
                    {
                        bestLogP = l;
                        bestPrev = j;
                    }
                }

                path[t][i] = bestPrev;
                currLogP[i] = bestLogP +
                    LogProb(m_stateModel[i]->GetProbability(seq[t]));
            }
        }

        // Termination: pick the best final state.
        int finalState = 0;
        double prob = -1e308;
        for (i = 0; i < m_stateNum; i++)
        {
            if (currLogP[i] > prob)
            {
                prob = currLogP[i];
                finalState = i;
            }
        }

        // Backtrack from the final state, then reverse into chronological order.
        state.push_back(finalState);
        for (t = size - 2; t >= 0; t--)
        {
            const int prevState = path[t + 1][state.back()];
            state.push_back(prevState);
        }
        reverse(state.begin(), state.end());

        return exp(prob / size);
    }

    训练问题:

    init:把所有样本的每个序列的特征值平均分给每个状态,然后用混合高斯模型表征每个状态。

    train:先用 Decode 解码,得到每个序列概率最大的一条路径,并对路径上出现的所有状态转移进行累计;然后用两个状态之间的转移次数除以该状态转移到所有状态的转移总数,得到状态转移概率(初始状态概率同理,由各序列起始状态的计数归一化得到)。迭代直到误差小于0.001。

    /*    SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...*/
    // Builds the initial model from a binary sample file laid out as
    //   <size><dim> then, per sequence, <seq_size><seq_size * dim doubles>.
    //
    // Steps:
    //   1. Set the initial and transition probabilities.
    //   2. Split every sequence evenly, in order, over the m_stateNum states.
    //   3. Write each state's share to "chmm_s<i>.tmp" and hand that file to
    //      the state's model via Train_Lee().
    //
    // sampleFileName - path of the sample file; it must be readable, hold at
    //                  least 3 sequences and use the models' dimension (these
    //                  are checked with assert only).
    void CHMM::Init(const char* sampleFileName)
    {
        //--- Debug ---//
        //DumpSampleFile(sampleFileName);

        // Check the sample file
        ifstream sampleFile(sampleFileName, ios_base::binary);
        assert(sampleFile);

        int i, j;
        int size = 0;
        int dim = 0;
        sampleFile.read((char*)&size, sizeof(int));  // number of sample sequences
        sampleFile.read((char*)&dim, sizeof(int));   // feature dimension
        assert(size >= 3);
        assert(dim == m_stateModel[0]->GetDimNum());

        // Left-to-right topology: the first state starts with probability 0.5
        // and the remaining states share the other 0.5; every state stays in
        // place or moves to the next state with probability 0.5 each.
        // NOTE(review): all other m_stateTran entries are left as they were;
        // this presumably relies on them being zero beforehand - confirm.
        for (i = 0; i < m_stateNum; i++)
        {
            // The initial probabilities
            if (i == 0)
                m_stateInit[i] = 0.5;
            else
                m_stateInit[i] = 0.5 / float(m_stateNum - 1);

            // The transition probabilities (column m_stateNum is the exit column)
            for (j = 0; j <= m_stateNum; j++)
            {
                if ((i == j) || (j == i + 1))
                    m_stateTran[i][j] = 0.5;
            }
        }

        // gaussseq[s] owns the feature vectors assigned to state s until they
        // have been written to that state's temp file below.
        vector<vector<double*> > gaussseq(m_stateNum);

        for (i = 0; i < size; i++) // every sample sequence
        {
            int seq_size = 0;
            sampleFile.read((char*)&seq_size, sizeof(int));  // sequence length

            // Each state receives about r consecutive feature vectors.
            double r = float(seq_size) / float(m_stateNum);
            for (j = 0; j < seq_size; j++)
            {
                double* x = new double[dim];
                sampleFile.read((char*)x, sizeof(double) * dim);

                // Clamp defensively: rounding in j / r must never index past
                // the last state.
                int bucket = int(j / r);
                if (bucket >= m_stateNum)
                    bucket = m_stateNum - 1;
                gaussseq[bucket].push_back(x);
            }
        }

        // Save each state's feature vectors to its temp file and initialise
        // the state's model from it.
        for (i = 0; i < m_stateNum; i++)
        {
            char stateFileName[20];
            ostrstream str(stateFileName, 20);
            // ostrstream does not terminate the buffer itself.
            str << "chmm_s" << i << ".tmp" << '\0';

            ofstream stateFile(stateFileName, ios_base::binary);
            int stateDataSize = (int)gaussseq[i].size();
            stateFile.write((char*)&stateDataSize, sizeof(int));
            stateFile.write((char*)&dim, sizeof(int));
            for (j = 0; j < stateDataSize; j++)
            {
                double* x = gaussseq[i][j];
                stateFile.write((char*)x, sizeof(double) * dim);
                // Each vector was allocated with new[] above and is no longer
                // needed once written (the original leaked all but one and
                // released that one with scalar delete).
                delete[] x;
            }
            stateFile.close();
            gaussseq[i].clear();

            // Initialise this state's model from the file just written.
            m_stateModel[i]->Train_Lee(stateFileName, i);
        }
    }
    
    /*    SampleFile: <size><dim><seq_size><seq_data>...<seq_size>...*/
    // Trains the HMM on a binary sample file laid out as
    //   <size><dim> then, per sequence, <seq_size><seq_size * dim doubles>.
    //
    // After Init(), each iteration:
    //   1. decodes every sequence with Decode() to get its best state path;
    //   2. counts initial states and state transitions along those paths and
    //      appends each observation to the temp file of its decoded state;
    //   3. re-trains the per-state models from those files and re-estimates
    //      m_stateInit / m_stateTran from the counts.
    // Iteration stops after m_maxIterNum rounds or once the average
    // log-probability has failed to improve by m_endError (relative) for three
    // consecutive rounds; the final probabilities are appended to "model.txt".
    //
    // sampleFileName - path of the sample file (validated with assert only).
    void CHMM::Train(const char* sampleFileName)
    {
        Init(sampleFileName);

        //--- Debug ---//
        DumpSampleFile(sampleFileName);

        // Check the sample file
        ifstream sampleFile(sampleFileName, ios_base::binary);
        assert(sampleFile);
        int i, j;

        int size = 0;
        int dim = 0;
        sampleFile.read((char*)&size, sizeof(int));
        sampleFile.read((char*)&dim, sizeof(int));
        assert(size >= 3);
        assert(dim == m_stateModel[0]->GetDimNum());

        // Buffers for the new model. Column m_stateNum of stateTranNum counts
        // sequences that end in the given state.
        vector<int> stateInitNum(m_stateNum, 0);
        vector<vector<int> > stateTranNum(m_stateNum, vector<int>(m_stateNum + 1, 0));
        vector<int> stateDataSize(m_stateNum, 0);
        vector<vector<char> > stateFileName(m_stateNum, vector<char>(20, '\0'));
        ofstream* stateFile = new ofstream[m_stateNum];

        for (i = 0; i < m_stateNum; i++)
        {
            ostrstream str(&stateFileName[i][0], 20);
            // ostrstream does not terminate the buffer itself.
            str << "chmm_s" << i << ".tmp" << '\0';
        }

        bool loop = true;
        double currL = 0;
        double lastL = 0;
        int iterNum = 0; // number of iterations done
        int unchanged = 0;
        vector<int> state;
        vector<double*> seq;

        while (loop)
        {
            lastL = currL;
            currL = 0;

            // Clear buffers and open temp data files. The leading count is a
            // placeholder that is rewritten once the real size is known.
            for (i = 0; i < m_stateNum; i++)
            {
                stateDataSize[i] = 0;
                stateFile[i].open(&stateFileName[i][0], ios_base::binary);
                stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
                stateFile[i].write((char*)&dim, sizeof(int));
                stateTranNum[i].assign(m_stateNum + 1, 0);
            }
            stateInitNum.assign(m_stateNum, 0);

            // Predict: obtain the best path for every sequence.
            // Clear any error flag left by the previous pass so the rewind works.
            sampleFile.clear();
            sampleFile.seekg(sizeof(int) * 2, ios_base::beg);
            for (i = 0; i < size; i++)
            {
                int seq_size = 0;
                sampleFile.read((char*)&seq_size, sizeof(int));
                if (seq_size <= 0)
                {
                    // Nothing to decode; the indexing below (state[0],
                    // state[j-1]) would be out of bounds.
                    continue;
                }

                for (j = 0; j < seq_size; j++)
                {
                    double* x = new double[dim];
                    sampleFile.read((char*)x, sizeof(double) * dim);
                    seq.push_back(x);
                }

                currL += LogProb(Decode(seq, state)); // Viterbi decoding

                stateInitNum[state[0]]++;
                for (j = 0; j < seq_size; j++)
                {
                    stateFile[state[j]].write((char*)seq[j], sizeof(double) * dim);
                    stateDataSize[state[j]]++;
                    if (j > 0)
                    {
                        stateTranNum[state[j - 1]][state[j]]++;
                    }
                }
                stateTranNum[state[seq_size - 1]][m_stateNum]++; // Final state

                for (j = 0; j < seq_size; j++)
                {
                    delete[] seq[j];
                }
                state.clear();
                seq.clear();
            }
            currL /= size;

            // Patch the real element count into each temp file and close it.
            for (i = 0; i < m_stateNum; i++)
            {
                stateFile[i].seekp(0, ios_base::beg);
                stateFile[i].write((char*)&stateDataSize[i], sizeof(int));
                stateFile[i].close();
            }

            // Reestimate: stateModel, stateInit, stateTran
            int count = 0;
            for (j = 0; j < m_stateNum; j++)
            {
                // Only re-train a state's model when it received more than
                // two samples per mixture component.
                if (stateDataSize[j] > m_stateModel[j]->GetMixNum() * 2)
                {
                    m_stateModel[j]->DumpSampleFile(&stateFileName[j][0]);
                    m_stateModel[j]->Train_Lee(&stateFileName[j][0], j);
                }
                count += stateInitNum[j];
            }
            if (count > 0) // keep the old values rather than dividing by zero
            {
                for (j = 0; j < m_stateNum; j++)
                {
                    m_stateInit[j] = 1.0 * stateInitNum[j] / count;
                }
            }
            for (i = 0; i < m_stateNum; i++)
            {
                count = 0;
                for (j = 0; j < m_stateNum + 1; j++)
                {
                    count += stateTranNum[i][j];
                }
                if (count > 0)
                {
                    for (j = 0; j < m_stateNum + 1; j++)
                    {
                        m_stateTran[i][j] = 1.0 * stateTranNum[i][j] / count;
                    }
                }
            }

            // Terminal conditions
            iterNum++;
            unchanged = (currL - lastL < m_endError * fabs(lastL)) ? (unchanged + 1) : 0;
            if (iterNum >= m_maxIterNum || unchanged >= 3)
            {
                loop = false;
                ofstream fout("model.txt", ofstream::app);
                fout << endl;
                for (j = 0; j < m_stateNum; j++)
                {
                    fout << m_stateInit[j] << " ";
                }
                fout << endl;
                for (i = 0; i < m_stateNum; i++)
                {
                    for (j = 0; j < m_stateNum + 1; j++)
                    {
                        fout << m_stateTran[i][j] << " ";
                    }
                    fout << endl;
                }
            }
            //DEBUG
            //cout << "Iter: " << iterNum << ", Average Log-Probability: " << currL << endl;
        }

        delete[] stateFile;
    }
  • 相关阅读:
    jbpm4.4使用的hibernate3如何兼容spring5.x及异常Caused by: java.lang.ClassNotFoundException: org.hibernate.impl.SessionImpl
    Caused by: java.lang.ClassNotFoundException: io.netty.resolver.AddressResolverGroup
    Caused by: java.lang.ClassNotFoundException: org.jboss.marshalling.ClassResolver
    Caused by: java.lang.ClassNotFoundException: com.fasterxml.jackson.dataformat.yaml.YAMLFactory
    Redisson报错Caused by: java.lang.IllegalArgumentException: RIVER
    redis中StringRedisTemplate的setIfAbsent方法设置过期时间
    xshell下载
    mysql下载地址
    最小化可行产品MVP
    电梯演讲
  • 原文地址:https://www.cnblogs.com/jerrice/p/4354979.html
Copyright © 2011-2022 走看看