zoukankan      html  css  js  c++  java
  • 条件随机场之CRF++源码详解-特征

      我在学习条件随机场的时候经常有这样的疑问,crf预测当前节点label如何利用其他节点的信息、crf的训练样本与其他的分类器有什么不同、crf的公式中特征函数是什么以及这些特征函数是如何表示的。在这一章中,我将在CRF++源码中寻找答案。

    输入过程

      CRF++训练的入口在crf_learn.cpp文件的main函数中,在该函数中调用了encoder.cpp的crfpp_learn(int argc, char **argv)函数。在CRF++中,训练被称为encoder,显然预测就称为decoder。crfpp_learn的源码如下:

    1 int crfpp_learn(int argc, char **argv) {
    2   CRFPP::Param param; //存放输入的参数
    3   param.open(argc, argv, CRFPP::long_options); //处理命令行输入的参数,存在param对象中
    4   return CRFPP::crfpp_learn(param);
    5 }

      Param对象主要存放输入的参数,调用open方法处理命令行输入的参数并存储。最后调用crfpp_learn(const Param &param)函数,在该函数中将初始化Encoder对象encoder,并调用encoder的learn方法。

    样本的处理以及特征的构造

      本章的重点便是这个learn方法,该方法主要是根据输入的样本和特征模板构造特征。阅读该函数源码之前可以去CRF++官网了解一下CRF++输入的参数,以及模板文件和训练文件的格式。

      1 bool Encoder::learn(const char *templfile, //模板文件
      2                     const char *trainfile, //训练样本
      3                     const char *modelfile, //模型输出文件
      4                     bool textmodelfile, 
      5                     size_t maxitr,
      6                     size_t freq,
      7                     double eta,
      8                     double C,
      9                     unsigned short thread_num,
     10                     unsigned short shrinking_size,
     11                     int algorithm) {
     12   std::cout << COPYRIGHT << std::endl;
     13 
     14   CHECK_FALSE(eta > 0.0) << "eta must be > 0.0"; //CHECK_FALSE是宏定义,如果传入的条件是false,则输出异常信息
     15   CHECK_FALSE(C >= 0.0) << "C must be >= 0.0";
     16   CHECK_FALSE(shrinking_size >= 1) << "shrinking-size must be >= 1";
     17   CHECK_FALSE(thread_num > 0) << "thread must be > 0";
     18 
     19 #ifndef CRFPP_USE_THREAD
     20   CHECK_FALSE(thread_num == 1)
     21       << "This architecture doesn't support multi-thrading";
     22 #endif
     23 
     24   if (algorithm == MIRA && thread_num > 1) {//MIRAS算法无法启用多线程
     25     std::cerr <<  "MIRA doesn't support multi-thrading. use thread_num=1"
     26               << std::endl;
     27   }
     28 
     29   EncoderFeatureIndex feature_index; //所有的特征将存储在feature_index中
     30   Allocator allocator(thread_num); //allocator对象主要用来做资源分配以及回收
     31   std::vector<TaggerImpl* > x; //x存放输入的样本,例如:如果做词性标注的话,TaggerTmpl对象存放的是每句话,而x是所有句子
     32 
     33   std::cout.setf(std::ios::fixed, std::ios::floatfield);
     34   std::cout.precision(5);
     35 
     36 #define WHAT_ERROR(msg) do {                                    
     37     for (std::vector<TaggerImpl *>::iterator it = x.begin();    
     38          it != x.end(); ++it)                                   
     39       delete *it;                                               
     40     std::cerr << msg << std::endl;                              
     41     return false; } while (0)
     42 
     43   CHECK_FALSE(feature_index.open(templfile, trainfile)) //打开“模板文件”和“训练文件”
     44       << feature_index.what();
     45 
     46   {
     47     progress_timer pg;
     48 
     49     std::ifstream ifs(WPATH(trainfile));
     50     CHECK_FALSE(ifs) << "cannot open: " << trainfile;
     51 
     52     std::cout << "reading training data: " << std::flush;
     53     size_t line = 0;
     54     while (ifs) {      //开始读取训练样本
     55       TaggerImpl *_x = new TaggerImpl(); //_x存放的是一句话的内容,CRF++官网中提到,用一个空白行将每个sentence隔开
     56       _x->open(&feature_index, &allocator); //做一些属性赋值,所有的句子都对应相同的feature_index和allocator对象
     57       if (!_x->read(&ifs) || !_x->shrink()) {
     58         WHAT_ERROR(_x->what());
     59       }
     60 
     61       if (!_x->empty()) {
     62         x.push_back(_x);
     63       } else {
     64         delete _x;
     65         continue;
     66       }
     67 
     68       _x->set_thread_id(line % thread_num); //每个句子都会分配一个线程id,可以多线程并发处理不同的句子
     69 
     70       if (++line % 100 == 0) {
     71         std::cout << line << ".. " << std::flush;
     72       }
     73     }
     74 
     75     ifs.close();
     76     std::cout << "
    Done!";
     77   }
     78 
     79   feature_index.shrink(freq, &allocator); // 根据训练是指定的-f参数,将特征出现的频率小于freq的过滤掉
     80 
     81   std::vector <double> alpha(feature_index.size());           // feature_index.size()返回的是maxid_,即:特征函数的个数,alpha是每个特征函数的权重,便是CRF中要学习的参数
     82   std::fill(alpha.begin(), alpha.end(), 0.0);
     83   feature_index.set_alpha(&alpha[0]);
     84 
     85   std::cout << "Number of sentences: " << x.size() << std::endl;
     86   std::cout << "Number of features:  " << feature_index.size() << std::endl;
     87   std::cout << "Number of thread(s): " << thread_num << std::endl;
     88   std::cout << "Freq:                " << freq << std::endl;
     89   std::cout << "eta:                 " << eta << std::endl;
     90   std::cout << "C:                   " << C << std::endl;
     91   std::cout << "shrinking size:      " << shrinking_size
     92             << std::endl;
     93 
     94   ... //省略后续代码 
    95
    }

    我阅读源码是按照深度优先遍历的方式,遇到一个函数会不断地深入进去,直到理解了该函数的功能再返回。上述源码需要重点介绍的部分,我也按照深度优先的方式记录。对于比较容易理解的部分则直接在源码中添加注释。首先看下源码第43行feature_index.open(templfile, trainfile),表面是理解是打开模板文件和训练集文件,但具体做了什么事儿呢,进入这个函数发现分别调用了两个函数。一个是EncoderFeatureIndex::openTemplate(const char *filename),这个函数主要是读取模板文件中的unigram特征和bigram特征分别存储,从官网文章中也可以知道,crf的特征分为两种特征,unigram对应的是状态特征,bigram对应的是转移特征。另一个函数是EncoderFeatureIndex::openTagSet(const char *filename),该函数读取训练集文件,获得训练集特征的数量(feature_index.xsize_属性)以及训练集中label的集合(feature_index.y_属性),以后可以用集合中label值得的索引代替label。

      learn函数的第57行,有两个函数调用。一个是_x->read(&ifs),这个函数是对输入的样本做处理。解释该函数之前,我先做一个约定,以词性标注为例。我们输入的训练样本每一行代表一个词,每一列代表词的特征,多个词(多行)代表一个句子,句子与句子之间用空白行分隔。这个规则从CRF++文档中也能看出,我们就统一用句子和词表示,方便表达。那么,该函数会读取一个句子。经过层层调用,会对_x对象中几个重要的数据结构进行赋值,由于这个函数的处理逻辑不复杂,因此我直接给出最终赋值的结果。如下:

    class TaggerImpl : public Tagger {
      FeatureIndex   *feature_index_;
      Allocator      *allocator_;
      std::vector<std::vector <const char *> > x_; //代表一个句子,外部vector代表多行(多个词),内部vector代表每行的多列,具体的列用char*表示
      std::vector<std::vector <Node *> > node_;    //相当于二位数组,node_[i][j]表示一个节点,即:第i个词是第j个label的点。如:“我”这个词是“代词”
      std::vector<unsigned short int>  answer_;    //每个词对应的label
      std::vector<unsigned short int>  result_;    
    };

     另一个调用是_x->shrink(),该函数的主要功能就是构造特征,具体来说是调用了feature_index的FeatureIndex::buildFeatures(TaggerImpl *tagger)方法,源码如下:

    #define ADD { const int id = getID(os.c_str()); 
      if (id != -1) feature.push_back(id); } while (0)
    bool FeatureIndex::buildFeatures(TaggerImpl *tagger) const {
      string_buffer os;
      std::vector<int> feature;
    
      FeatureCache *feature_cache = tagger->allocator()->feature_cache(); //存放是每个节点或者边对应的特征向量,节点便是node[i][j],边的概念后续会接触,暂时可以忽略
      tagger->set_feature_id(feature_cache->size()); //做个标记,以后要取该句子的特征,可以从该id的位置取
    
      for (size_t cur = 0; cur < tagger->size(); ++cur) {//遍历每个词,计算每个词的特征
        for (std::vector<std::string>::const_iterator it
                 = unigram_templs_.begin();
             it != unigram_templs_.end(); ++it) { //遍历每个unigram特征
          if (!applyRule(&os, it->c_str(), cur, *tagger)) {applyRule函数根据当前词(cur)以及当前的特征(如: %x[-2,0]),生成一个特征,存放在os中
            return false;
          }
          ADD; //将根据特征os,获取该特征的id,如果不存在该特征,生成新的id,将该id添加到feature变量中
        }
        feature_cache->add(feature); //将该词的特征添加到feature_cache中,add方法会将feature拷贝一份并将最后添加-1,方便后续读取
        feature.clear();
      }
    
      for (size_t cur = 1; cur < tagger->size(); ++cur) {//遍历每条边,计算每条边的特征
        for (std::vector<std::string>::const_iterator
                 it = bigram_templs_.begin();
             it != bigram_templs_.end(); ++it) {//遍历每个bigram特征
          if (!applyRule(&os, it->c_str(), cur, *tagger)) {//处理同上
            return false;
          }
          ADD;
        }
        feature_cache->add(feature);
        feature.clear();
      }
    
      return true;
    }

     经过上面处理,最终会存储节点(单词)和边(相邻词连接)的特征列表(函数中feature变量),并存储在feature_cache中,由于在该函数中调用了set_feature_id方法,因此很容易拿到每个句子对应的特征列表。这里需要关注一下applyRule函数和ADD宏定义中的getID函数。下面我将举个例子,来直观感受下这两个函数的功能。

    tempfile:

      # Unigram
      U00:%x[-1,0]
      U01:%x[0,0]

     trainfile:

      0 - -1 -1 -1 -1 O
      0 submit 7 0 0 0 B
      1 submit 3 4 0 0 E

    先看下CRF++中的特征模板,模板文件比较简单,只有unigram特征,特征的表示形如 U00:%x[a,b],开头的'U'代表unigram特征还是bigram特征,b代表的是哪列特征,a代表的是当前词的行偏移量。样本集文件更简单,只有一个句子,该句子有3个单词,每个单词有6个特征。

    1) 当cur=0,遍历第一个unigram特征U00:%x[-1,0], 0代表第0个特征(第0列),-1代表前一个词的第0个特征。由于第一个词没有前一个词,所以CRF++中用_B-1代替,这部分可在源码中找到。调用applyRule将会生成"U00:_B-1"特征,调用getID函数返回的maxid_并存储在feature_index的dic_属性中,maxid_初始值为0,如果当前特征是新的则返回maxid_并更新maxid_为新值,maxid更新代码为maxid_ += (key[0] == 'U' ? y_.size() : y_.size() * y_.size()); 由于unigram是状态特征label与当前节点有关,所以加y_.size()表示y_.size()个特征函数,而bigram表示转移特征(边),与当前状态和前一个状态有关,有y_.size() * y_.size()种情况,因此加上y_.size() * y_.size(),代表y_.size()*y_.size()个特征函数。以上述例子unigram来说,对于某个词的特征,该词的label可能有y_.size()种情况,最终生成的特征函数是 f(特征='U00:_B-1', y='O')=1,f(特征='U00:_B-1', y='B')=1,f(特征='U00:_B-1', y='E')=1。总结一下,对于这个例子来说,一个unigram特征对应3状态特征函数,一个bigram特征对应9个转移特征函数。

    2) 当cur=0,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:0",调用getID函数,返回特征id为3,feature变量为[0,3]

    3) 当cur=1,遍历第一个unigram特征U00:%x[-1,0],调用applyRule生成特征"U00:0",调用getID函数,返回特征id为6

    4) 当cur=1,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:0",调用getID函数,返回特征id为3, feature变量为[6,3]

    5) 当cur=2,遍历第一个unigram特征U00:%x[-1,0],调用applyRule生成特征"U00:0",调用getID函数,返回特征id为6

    6) 当cur=2,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:1",调用getID函数,返回特征id为9, 此时maxid_更新为12,feature变量为[6,9]

    因此,特征一共有4个,状态特征有12个,转移特征为0个,因此feature_index的maxid_为12,feature_cache的大小为5(3个节点+2条边)。本例子中只有1句话并且只有一个特征的unigram特征函数,对于多句话和多个特征函数,计算逻辑是一样的,并且都会更新到公共的变量feature_index中。

     至此,就_x->shrink()的核心逻辑便梳理完毕, 同时也是整个learn函数的核心逻辑,回到learn函数的源码继续往下看,while循环会对每个句子重复进行上述操作,并将表示句子的变量x_存储到变量x中,代表整个训练集。还有需要注意的是我们平时一般用w表示待学习的参数,但在CRF++中使用变量alpha表示w。

    总结

      本章主要结合源码和实际的例子,了解了CRF++如何处理输入的样本,如何生成特征以及特征函数。首先,通过本章可以清晰的找到开头提到的几个问题。其次,可以学习CRF++如何定义数据结构表示条件随机场各个元素及其之间的关系,如果再仔细体会一下,就能发现CRF++里设计的数据结构和代码实现还是非常巧妙的,值得学习。如对本章内容有疑问的欢迎在留言区交流,我会及时回复,同时如有表述不对的地方,烦请指正。

      

  • 相关阅读:
    Pytorch-基于Transformer的情感分类
    Pytorch-LSTM+Attention文本分类
    .NET ------ 批量修改
    idea ---- idea中关联GitHub
    .NET ----- 将文本框改成下划线,将下拉框改为下拉下划线
    表设计(省市县)
    锁:并发编程中的三个问题(可见性、原子性、有序性)
    freemarker:常用指令、null值的处理、基本数据类型、自定义指令
    vue:绑定属性指令(绑定属性、绑定class(对象语法、数组语法))
    vue:指令(插值操作、指令(v-once、v-html、v-text、v-pre、v-cloak))
  • 原文地址:https://www.cnblogs.com/duma/p/10293190.html
Copyright © 2011-2022 走看看