zoukankan      html  css  js  c++  java
  • 条件随机场之CRF++源码详解-预测

      这篇文章主要讲解CRF++实现预测的过程,预测的算法以及代码实现相对来说比较简单,所以这篇文章理解起来也会比上一篇条件随机场训练的内容要容易。

    预测

      上一篇条件随机场训练的源码详解中,有一个地方并没有介绍。 就是训练结束后,会把待优化权重alpha等变量保存到文件中,也就是输出到指定的模型文件。在执行预测的时候会从模型文件读出相关的变量,这个过程其实就是数据序列化与反序列化,该过程跟条件随机场算法关系不大,因此为了突出重点源码解析里就没有介绍这部分,有兴趣的朋友可以自己研究一下。

      CRF++预测的入口代码在crf_test.cpp的main函数中,最终会调用tragger.cpp的int crfpp_test(const Param &param)函数,期间会做一些输入参数的处理、异常处理、读取模型文件等操作。一切准备就绪就会打开待预测的文件,进行预测。正式探讨预测代码之前,我们先看下预测的理论基础。条件随机场的预测用到了维特比算法,公式如下:

    egin{aligned} y^* &= arg max_yP_w(y|x) \ &=  arg max_yfrac{ exp left { sum_{k=1}^Kw_kf_k(y,x) ight}}{Z_w(x)} \ &=  arg max_y exp left {sum_{k=1}^Kw_kf_k(y,x) ight} \ &= arg max_y sum_{k=1}^Kw_kf_k(y,x) end{aligned}

    从公式我们可以看出,我们求的概率最大值就是要求代价最大。接下来就看下CRF++的源码,代码在tragger.cpp的crfpp_test函数中:

    while (*is) {//is是打开的测试文件,可以输入多个测试文件做预测
          tagger.parse_stream(is.get(), os.get()); 
    }
    
    bool TaggerImpl::parse_stream(std::istream *is,
                                  std::ostream *os) {
      if (!read(is) || !parse()) {//read函数在特征篇讲过,不再赘述,调用parse函数进行预测
        return false;
      }
      if (x_.empty()) {
        return true;
      }
      toString(); //格式化输出,-v 会输出每个词预测为某个label的概率,-n会输出预测序列概率最大的前n个,如果理解上一篇训练过程,再看这个函数就比较容易理解,无非就是概率计算,这里不再赘述
      os->write(os_.data(), os_.size()); //输出到输出文件
      return true;
    }
    bool TaggerImpl::parse() {
      CHECK_FALSE(feature_index_->buildFeatures(this)) //构建特征,同特征篇代码,不再赘述
          << feature_index_->what();
    
      if (x_.empty()) {
        return true;
      }
      buildLattice(); //构建无向图,因为要计算代价最大的序列,训练篇讲过,不再赘述
      if (nbest_ || vlevel_ >= 1) {
        forwardbackward(); //前向后向算法,为了计算单词节点的概率,训练篇讲过,不再赘述
      }
      viterbi();  //维特比算法, 做预测的代码
      if (nbest_) {
        initNbest();
      }
    
      return true;
    }
    void TaggerImpl::viterbi() {
      for (size_t i = 0;   i < x_.size(); ++i) { //遍历每个词
        for (size_t j = 0; j < ysize_; ++j) { //遍历每个词的每个label
          double bestc = -1e37;
          Node *best = 0;
          const std::vector<Path *> &lpath = node_[i][j]->lpath;
          for (const_Path_iterator it = lpath.begin(); it != lpath.end(); ++it) { //从前一个词到当前词的代价之和 = max(前一个节点的代价 + 前一个节点的边代价 + 当前节点代价)
            double cost = (*it)->lnode->bestCost +(*it)->cost +
                node_[i][j]->cost;
            if (cost > bestc) { //记录截止当前节点最大的代价, 以及对应的前一个节点
              bestc = cost;
              best  = (*it)->lnode;
            }
          }
          node_[i][j]->prev     = best; //记录前一个几点
          node_[i][j]->bestCost = best ? bestc : node_[i][j]->cost; //记录最大的代价值, 如果best = 0代表第一个词,没有左边,最大代价就是节点的代价node_[i][j]->cost
        }
      }
    
      double bestc = -1e37;
      Node *best = 0;
      size_t s = x_.size()-1;
      for (size_t j = 0; j < ysize_; ++j) { //遍历最后一个词的节点,截止到最后一个词的代价最大值就是整个句子的最大代价
        if (bestc < node_[s][j]->bestCost) {
          best  = node_[s][j];
          bestc = node_[s][j]->bestCost;
        }
      }
    
      for (Node *n = best; n; n = n->prev) {//记录代价最大的预测序列
        result_[n->x] = n->y;
      }
    
      cost_ = -node_[x_.size()-1][result_[x_.size()-1]]->bestCost;
    }

     预测的核心代码就看完了,大部分复用了训练过程的逻辑。可以看到预测的过程跟公式是一致的,无非就是求能够让代价最大的label序列(标记序列),这就是维特比算法。

    总结

      至此,我们的条件随机场之CRF++源码详解系列就结束了,主要涵盖了特征处理、训练以及预测三个核心过程。结合CRF++源码我们可以更形象的、更通俗的去理解条件随机场模型。以后想起条件随机场模型,我们脑海浮现的不再是一堆公式,而是一个无向图,在图上进行代价计算、前向后向计算、期望值的计算以及梯度的计算等一系列的过程。希望这个系列对于正在学习条件随机场的朋友能有帮助,如果本文阐述的有歧义、不通俗、不容易理解的地方,欢迎留言区交流,我将及时更正、回复,希望我们一起提高。

  • 相关阅读:
    2020年秋招三星面试题
    物联网金融和互联网金融的区别与联系
    数据库事务的4种隔离级别
    Access-cookie之sqlmap注入
    SDL-软件安全开发周期流程
    图片马的制作
    ssrf内网端口爆破扫描
    逻辑漏洞_验证码绕过_密码找回漏洞
    平行越权与垂直越权
    xff注入
  • 原文地址:https://www.cnblogs.com/duma/p/10344232.html
Copyright © 2011-2022 走看看