• Kaldi chain model sequence-discriminative training: code analysis

chainbin/nnet3-chain-train.cc

int main(int argc, char *argv[]) {
  ...
  Nnet nnet;
  ReadKaldiObject(nnet_rxfilename, &nnet);
  bool ok;
  {
    fst::StdVectorFst den_fst;
    ReadFstKaldi(den_fst_rxfilename, &den_fst);

    // NnetChainTrainer takes the training options opts, the denominator FST
    // den_fst and the neural network nnet.
    NnetChainTrainer trainer(opts, den_fst, &nnet);

    // SequentialNnetChainExampleReader reads the training examples one
    // NnetChainExample (minibatch) at a time.
    SequentialNnetChainExampleReader example_reader(examples_rspecifier);

    for (; !example_reader.Done(); example_reader.Next())
      // train on one example at a time
      trainer.Train(example_reader.Value());

    ok = trainer.PrintTotalStats();
  }
  ...
  WriteKaldiObject(nnet, nnet_wxfilename, binary_write);
  ...
}
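For orientation, the binary follows the usual Kaldi command-line pattern; an illustrative invocation (the exact egs pipeline and options depend on the recipe, so treat this as a sketch rather than the canonical command) looks roughly like:

    nnet3-chain-train [options] <raw-model-in> <denominator-fst-in> <training-examples-in> <raw-model-out>
    nnet3-chain-train 0.raw den.fst "ark:nnet3-chain-copy-egs ark:cegs.1.ark ark:- |" 1.raw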

nnet3/nnet-chain-training.cc

void NnetChainTrainer::Train(const NnetChainExample &chain_eg) {
  bool need_model_derivative = true;
  const NnetTrainerOptions &nnet_config = opts_.nnet_config;
  bool use_xent_regularization = (opts_.chain_config.xent_regularize != 0.0);
  ComputationRequest request;
  // This function takes an NnetChainExample and produces a ComputationRequest.
  GetChainComputationRequest(*nnet_, chain_eg, need_model_derivative,
                             nnet_config.store_component_stats,
                             use_xent_regularization, need_model_derivative,
                             &request);
  // Compile the request.  The returned const pointer is owned/cached by the
  // CachingOptimizingCompiler NnetChainTrainer::compiler_; we hold it through
  // a std::shared_ptr<const NnetComputation>.
  std::shared_ptr<const NnetComputation> computation = compiler_.Compile(request);

  if (nnet_config.backstitch_training_scale > 0.0 && num_minibatches_processed_
      % nnet_config.backstitch_training_interval ==
      srand_seed_ % nnet_config.backstitch_training_interval) {
    // backstitch training is incompatible with momentum > 0
    KALDI_ASSERT(nnet_config.momentum == 0.0);
    FreezeNaturalGradient(true, delta_nnet_);
    bool is_backstitch_step1 = true;
    srand(srand_seed_ + num_minibatches_processed_);
    ResetGenerators(nnet_);
    TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);
    FreezeNaturalGradient(false, delta_nnet_);  // un-freeze natural gradient
    is_backstitch_step1 = false;
    srand(srand_seed_ + num_minibatches_processed_);
    ResetGenerators(nnet_);
    TrainInternalBackstitch(chain_eg, *computation, is_backstitch_step1);
  } else {  // conventional training
    TrainInternal(chain_eg, *computation);
  }

  num_minibatches_processed_++;
}
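The backstitch branch is just two calls to TrainInternalBackstitch(); for reference, the update it implements (the backstitch scheme of Wang et al.), with α = backstitch_training_scale and ν the effective learning rate, is approximately

    \theta' = \theta_t - \alpha\,\nu\,\hat g(\theta_t)            (step 1, the "backward" step)
    \theta_{t+1} = \theta' + (1+\alpha)\,\nu\,\hat g(\theta')      (step 2, the "forward" step)

where \hat g(\cdot) is the stochastic gradient on the same minibatch; the assertion above reflects that this scheme is only supported with momentum disabled.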

       

       

void NnetChainTrainer::TrainInternal(const NnetChainExample &eg,
                                     const NnetComputation &computation) {
  // The NnetComputer class is responsible for executing the computation
  // described by the 'computation' object.  It is called in this order:
  //   constructor
  //   AcceptInput() [or AcceptInputs()]
  //   Run()
  //   GetOutput()
  //   AcceptOutputDeriv() [if applicable]
  //   Run() [if a backward pass is needed]
  //   GetInputDeriv() [if applicable]
  NnetComputer computer(nnet_config.compute_config, computation,
                        nnet_, delta_nnet_);
  computer.AcceptInputs(*nnet_, eg.inputs);
  // forward pass
  computer.Run();
  // ProcessOutputs() calls GetOutput(), computes the objective and hands the
  // output derivatives back to the computer
  this->ProcessOutputs(false, eg, &computer);
  // backward pass; accumulates the parameter change in delta_nnet_
  computer.Run();
  // adjust the parameter change delta_nnet_ according to the L2 regularization term
  ApplyL2Regularization(*nnet_,
                        GetNumNvalues(eg.inputs, false) *
                            nnet_config.l2_regularize_factor,
                        delta_nnet_);
  // apply delta_nnet_ to the network, limiting the change per component and
  // globally to nnet_config.max_param_change
  bool success =
      UpdateNnetWithMaxChange(*delta_nnet_,
                              nnet_config.max_param_change,
                              1.0,
                              1.0 - nnet_config.momentum,
                              nnet_,
                              &num_max_change_per_component_applied_,
                              &num_max_change_global_applied_);
  ...
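UpdateNnetWithMaxChange() is not reproduced here; conceptually (a simplified sketch that ignores the per-component max-change values stored inside individual components), it applies

    \theta \leftarrow \theta + \min\!\Big(1,\ \tfrac{\text{max\_param\_change}}{\lVert\Delta\rVert_2}\Big)\,(1-\text{momentum})\,\Delta,

where Δ is the accumulated change in delta_nnet_, with an analogous per-component clipping applied before the global one; the two counters passed in record how often each limit was hit.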

      

       

       

void NnetChainTrainer::ProcessOutputs(bool is_backstitch_step2,
                                      const NnetChainExample &eg,
                                      NnetComputer *computer) {
  // normally the eg will have just one output named 'output', but
  // we don't assume this.
  // In backstitch training, the output-name with the "_backstitch" suffix is
  // the one computed after the first, backward step of backstitch.
  const std::string suffix = (is_backstitch_step2 ? "_backstitch" : "");
  std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),
      end = eg.outputs.end();
  for (; iter != end; ++iter) {
    // check that each supervision in the example matches an output node of the network
    const NnetChainSupervision &sup = *iter;
    int32 node_index = nnet_->GetNodeIndex(sup.name);
    if (node_index < 0 ||
        !nnet_->IsOutputNode(node_index))
      KALDI_ERR << "Network has no output named " << sup.name;

    const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
    CuMatrix<BaseFloat> nnet_output_deriv(nnet_output.NumRows(),
                                          nnet_output.NumCols(),
                                          kUndefined);
    // whether cross-entropy regularization is used
    bool use_xent = (opts_.chain_config.xent_regularize != 0.0);
    // the cross-entropy objective is computed on the output node named
    // <output>-xent
    std::string xent_name = sup.name + "-xent";  // typically "output-xent".
    CuMatrix<BaseFloat> xent_deriv;
    // tot_objf: the chain objective, excluding the L2 regularization term and
    //           the cross-entropy regularization term
    // tot_l2_term: the L2 regularization term
    // tot_weight: the total frame weight, used to normalize the objective
    BaseFloat tot_objf, tot_l2_term, tot_weight;
    // From the network output and the supervision, compute the
    // sequence-discriminative (chain) objective and its derivative, plus the
    // data needed for cross-entropy regularization ("frame smoothing" of the
    // sequence criterion).
    ComputeChainObjfAndDeriv(opts_.chain_config, den_graph_,
                             sup.supervision, nnet_output,
                             &tot_objf, &tot_l2_term, &tot_weight,
                             &nnet_output_deriv,
                             (use_xent ? &xent_deriv : NULL));

    // update the objective-function statistics
    if (use_xent) {
      // get the output of the cross-entropy output-node from the network
      const CuMatrixBase<BaseFloat> &xent_output = computer->GetOutput(
          xent_name);
      // At this point, xent_deriv holds the numerator posteriors (the
      // numerator "error signal" of the MMI-style criterion), so this trace
      // is the cross-entropy objective against those soft targets.
      BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans);
      objf_info_[xent_name + suffix].UpdateStats(xent_name + suffix,
                                                 opts_.nnet_config.print_interval,
                                                 num_minibatches_processed_,
                                                 tot_weight, xent_objf);
    }
    // apply the per-frame derivative weights
    if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) {
      CuVector<BaseFloat> cu_deriv_weights(sup.deriv_weights);
      nnet_output_deriv.MulRowsVec(cu_deriv_weights);
      if (use_xent)
        // xent_deriv = diag(cu_deriv_weights) * xent_deriv,
        // i.e. cu_deriv_weights[i] scales the i-th row of xent_deriv
        xent_deriv.MulRowsVec(cu_deriv_weights);
    }
    // hand the derivative w.r.t. this output back to the computer
    computer->AcceptInput(sup.name, &nnet_output_deriv);

    objf_info_[sup.name + suffix].UpdateStats(sup.name + suffix,
                                              opts_.nnet_config.print_interval,
                                              num_minibatches_processed_,
                                              tot_weight, tot_objf, tot_l2_term);

    if (use_xent) {
      // scale by the cross-entropy regularization constant
      xent_deriv.Scale(opts_.chain_config.xent_regularize);
      // hand the cross-entropy regularization derivative back to the computer
      computer->AcceptInput(xent_name, &xent_deriv);
    }
  }
}
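To spell out the cross-entropy regularization: writing γ^{num}_{t,i} for the numerator posteriors returned in xent_deriv and x_{t,i} for the (typically log-softmax) output of the "-xent" node, the statistic tracked above and the derivative handed back are, up to the optional per-frame deriv weights,

    F_{xent} = \sum_{t,i} \gamma^{num}_{t,i}\, x_{t,i}
    \qquad\text{and}\qquad
    \frac{\partial\,(\varepsilon F_{xent})}{\partial x} = \varepsilon\,\gamma^{num},

with ε = xent_regularize, so the total objective being maximized on the minibatch is F_chain + ε·F_xent plus the L2 term.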

    chain/chain-training.cc

// Note: this function only computes the data needed for the cross-entropy
// regularization term (the numerator posteriors); it does not itself apply
// that term to the derivative.
void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,
                              const DenominatorGraph &den_graph,
                              const Supervision &supervision,
                              const CuMatrixBase<BaseFloat> &nnet_output,
                              BaseFloat *objf,
                              BaseFloat *l2_term,
                              BaseFloat *weight,
                              CuMatrixBase<BaseFloat> *nnet_output_deriv,
                              CuMatrix<BaseFloat> *xent_output_deriv) {

  if (!supervision.e2e_fsts.empty()) {
    ComputeChainObjfAndDerivE2e(opts, den_graph, supervision,
                                nnet_output, objf, l2_term,
                                weight, nnet_output_deriv, xent_output_deriv);
    return;
  }

  BaseFloat num_logprob_weighted, den_logprob_weighted;
  bool ok = true;
  if (nnet_output_deriv != NULL)
    nnet_output_deriv->SetZero();

  {  // Doing the denominator first helps to reduce the maximum
     // memory use, as we can set 'xent_deriv' to nonempty after
     // we've freed the memory in this object.
    DenominatorComputation denominator(opts, den_graph,
                                       supervision.num_sequences,
                                       nnet_output);
    // denominator.Forward() returns the total log-probability of the
    // denominator graph
    den_logprob_weighted = supervision.weight * denominator.Forward();
    if (nnet_output_deriv)
      // the minus sign comes from the -log of the denominator in the objective
      ok = denominator.Backward(-supervision.weight,
                                nnet_output_deriv);
  }
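In other words, writing w = supervision.weight, Z_den for the total (pseudo-)likelihood of the denominator graph and γ^{den}_{t,i} for its per-frame occupation probabilities, this block computes

    \text{den\_logprob\_weighted} = w\,\log Z_{den},
    \qquad
    \text{nnet\_output\_deriv} \mathrel{-}= w\,\gamma^{den},

so that once the numerator pass below has added w·γ^{num}, the derivative is the familiar MMI error signal w(γ^{num} − γ^{den}).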

       

  if (xent_output_deriv != NULL) {
    // the reason for kStrideEqualNumCols is so that we can share the memory
    // block with the memory that was used for exp_nnet_output_transposed_ from
    // chain-denominator.cc, which has just been freed; it also uses the
    // kStrideEqualNumCols arg (its shape is the transpose of this matrix's
    // shape).
    xent_output_deriv->Resize(nnet_output.NumRows(), nnet_output.NumCols(),
                              kSetZero, kStrideEqualNumCols);
  }

       

  {
    // 'supervision' is the numerator FST built from the full transcript of an
    // utterance; it encodes the allowed time range of each phone sequence.
    // The NumeratorComputation class does the forward-backward computation
    // over this 'supervision' (numerator) FST, with nnet_output providing the
    // per-frame pseudo log-likelihoods.
    NumeratorComputation numerator(supervision, nnet_output);
    // note: supervision.weight is included as a factor in the derivative from
    // the numerator object, as well as the returned logprob.
    //
    // numerator.Forward() returns the log-probability of the numerator graph.
    // Unlike Kaldi nnet1-style sequence training, which works from lattices
    // and network posteriors, nnet3 chain training runs the forward-backward
    // directly on the numerator graph; since the graph ties together the
    // network's state distributions (pdf-ids), HMM states, phones and words,
    // the backward pass directly yields the numerator occupation
    // probabilities (posteriors).
    num_logprob_weighted = numerator.Forward();
    // Whether or not cross-entropy regularization is used, the
    // sequence-discriminative derivative nnet_output_deriv comes out the
    // same; the cross-entropy regularization term itself is NOT applied to
    // the derivative here.
    if (xent_output_deriv) {
      // xent_output_deriv receives the numerator posteriors ...
      numerator.Backward(xent_output_deriv);
      if (nnet_output_deriv)
        // ... and they are also added into the chain derivative
        nnet_output_deriv->AddMat(1.0, *xent_output_deriv);
    } else if (nnet_output_deriv) {
      numerator.Backward(nnet_output_deriv);
    }
  }

  *objf = num_logprob_weighted - den_logprob_weighted;

  *weight = supervision.weight * supervision.num_sequences *
      supervision.frames_per_sequence;

  // if the objective came out inf/NaN, or the denominator computation failed
  if (!((*objf) - (*objf) == 0) || !ok) {
    // inf or NaN detected, or denominator computation returned false.
    if (nnet_output_deriv)
      // zero the derivative
      nnet_output_deriv->SetZero();
    if (xent_output_deriv)
      // zero the cross-entropy derivative
      xent_output_deriv->SetZero();
    BaseFloat default_objf = -10;
    KALDI_WARN << "Objective function is " << (*objf)
               << " and denominator computation (if done) returned "
               << std::boolalpha << ok
               << ", setting objective function to " << default_objf
               << " per frame.";
    // fall back to the default per-frame objective, scaled by the total weight
    *objf = default_objf * *weight;
  }

       

  // This code helps us see how big the derivatives are, on average,
  // for different frames of the sequences.  As expected, they are
  // smaller towards the edges of the sequences (due to the penalization
  // of 'incorrect' pdf-ids).
  if (GetVerboseLevel() >= 1 && nnet_output_deriv != NULL && RandInt(0, 10) == 0) {
    int32 tot_frames = nnet_output_deriv->NumRows(),
        frames_per_sequence = supervision.frames_per_sequence,
        num_sequences = supervision.num_sequences;
    CuVector<BaseFloat> row_products(tot_frames);
    row_products.AddDiagMat2(1.0, *nnet_output_deriv, kNoTrans, 0.0);
    Vector<BaseFloat> row_products_cpu(row_products);
    Vector<BaseFloat> row_products_per_frame(frames_per_sequence);
    for (int32 i = 0; i < tot_frames; i++)
      row_products_per_frame(i / num_sequences) += row_products_cpu(i);
    KALDI_LOG << "Derivs per frame are " << row_products_per_frame;
  }

       

  if (opts.l2_regularize == 0.0) {
    *l2_term = 0.0;
  } else {
    // compute the l2 penalty term and its derivative
    BaseFloat scale = supervision.weight * opts.l2_regularize;
    // l2_term = -0.5 * scale * (sum of squares of the network output)
    *l2_term = -0.5 * scale * TraceMatMat(nnet_output, nnet_output, kTrans);
    if (nnet_output_deriv)
      // derivative of the l2 term: -scale * nnet_output
      nnet_output_deriv->AddMat(-1.0 * scale, nnet_output);
  }
}
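Collecting the pieces of this function: with w = supervision.weight, y the network output, λ = opts.l2_regularize, and γ^{num}, γ^{den} the numerator/denominator occupancies, the returned quantities are

    \text{objf} = w\big(\log p_{num}(y) - \log Z_{den}(y)\big),
    \qquad
    \text{l2\_term} = -\tfrac{w\lambda}{2}\,\lVert y\rVert_F^2,

    \frac{\partial(\text{objf} + \text{l2\_term})}{\partial y}
        = w\big(\gamma^{num} - \gamma^{den}\big) - w\lambda\,y,

i.e. the LF-MMI ("chain") objective plus the output-L2 penalty, while *weight returns the total frame weight used later for normalization.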

       

    chain/chain-numerator.cc

// Forward computation over the numerator FST; returns
// (total log-likelihood) * supervision_.weight
BaseFloat NumeratorComputation::Forward() {
  ComputeLookupIndexes();
  nnet_logprobs_.Resize(nnet_output_indexes_.Dim(), kUndefined);
  nnet_output_.Lookup(nnet_output_indexes_, nnet_logprobs_.Data());
  const fst::StdVectorFst &fst = supervision_.fst;
  KALDI_ASSERT(fst.Start() == 0);
  int32 num_states = fst.NumStates();
  log_alpha_.Resize(num_states, kUndefined);
  log_alpha_.Set(-std::numeric_limits<double>::infinity());
  tot_log_prob_ = -std::numeric_limits<double>::infinity();

  log_alpha_(0) = 0.0;  // note, state zero is the start state, we checked above

  const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();
  std::vector<int32>::const_iterator fst_output_indexes_iter =
      fst_output_indexes_.begin();

  double *log_alpha_data = log_alpha_.Data();

  for (int32 state = 0; state < num_states; state++) {
    double this_log_alpha = log_alpha_data[state];
    for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, state); !aiter.Done();
         aiter.Next(), ++fst_output_indexes_iter) {
      const fst::StdArc &arc = aiter.Value();
      int32 nextstate = arc.nextstate;
      BaseFloat transition_logprob = -arc.weight.Value();
      int32 index = *fst_output_indexes_iter;
      BaseFloat pseudo_loglike = nnet_logprob_data[index];
      double &next_log_alpha = log_alpha_data[nextstate];
      next_log_alpha = LogAdd(next_log_alpha, pseudo_loglike +
                              transition_logprob + this_log_alpha);
    }
    if (fst.Final(state) != fst::TropicalWeight::Zero()) {
      BaseFloat final_logprob = -fst.Final(state).Value();
      tot_log_prob_ = LogAdd(tot_log_prob_,
                             this_log_alpha + final_logprob);
    }
  }
  KALDI_ASSERT(fst_output_indexes_iter ==
               fst_output_indexes_.end());
  return tot_log_prob_ * supervision_.weight;
}
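In equation form, the loop above is the standard log-domain forward recursion over the (topologically sorted, epsilon-free) numerator FST. Writing ℓ(a) for the pseudo log-likelihood looked up from the network output for arc a, c(a) for the arc's transition log-probability, and ⊕ for LogAdd:

    \alpha(\text{next}(a)) \mathrel{\oplus}= \alpha(\text{src}(a)) + \ell(a) + c(a),
    \qquad
    \log p_{num} = \bigoplus_{s:\,\text{final}} \big(\alpha(s) + \text{final}(s)\big).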

       

       

// Backward computation over the numerator FST; adds the derivative of
// (total log-likelihood * supervision_.weight) with respect to the network
// output into nnet_output_deriv.
void NumeratorComputation::Backward(
    CuMatrixBase<BaseFloat> *nnet_output_deriv) {
  // the numerator FST
  const fst::StdVectorFst &fst = supervision_.fst;
  // number of states in the numerator FST
  int32 num_states = fst.NumStates();
  log_beta_.Resize(num_states, kUndefined);
  // derivatives w.r.t. the network pseudo log-likelihoods
  nnet_logprob_derivs_.Resize(nnet_logprobs_.Dim());

  // we'll be counting backwards and moving the 'fst_output_indexes_iter'
  // pointer back.
  // 'fst_output_indexes_' contains one entry per arc of the supervision FST,
  // in the order obtained by visiting the arcs of each state in turn; its
  // contents are indexes into nnet_output_indexes_ / nnet_logprobs_.
  const int32 *fst_output_indexes_iter = &(fst_output_indexes_[0]) +
      fst_output_indexes_.size();
  // nnet_logprobs_ holds the log-probs looked up from the nnet output, on CPU;
  // it has the same size as nnet_output_indexes_.
  const BaseFloat *nnet_logprob_data = nnet_logprobs_.Data();
  // tot_log_prob_ is the total pseudo log-likelihood from the forward pass
  double tot_log_prob = tot_log_prob_;
  double *log_beta_data = log_beta_.Data();
  const double *log_alpha_data = log_alpha_.Data();
  // nnet_logprob_derivs_ holds the derivatives w.r.t. the network
  // log-likelihoods, i.e. the arc occupation probabilities
  BaseFloat *nnet_logprob_deriv_data = nnet_logprob_derivs_.Data();
  // visit the states of the numerator FST in reverse topological order
  for (int32 state = num_states - 1; state >= 0; state--) {
    // number of arcs leaving this state
    int32 this_num_arcs = fst.NumArcs(state);
    // on the backward pass we access the fst_output_indexes_ vector in a
    // zigzag pattern, stepping the iterator back by the number of arcs of
    // each state.
    fst_output_indexes_iter -= this_num_arcs;
    const int32 *this_fst_output_indexes_iter = fst_output_indexes_iter;
    double this_log_beta = -fst.Final(state).Value();
    double this_log_alpha = log_alpha_data[state];
    // loop over the arcs leaving this state
    for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, state); !aiter.Done();
         aiter.Next(), this_fst_output_indexes_iter++) {
      const fst::StdArc &arc = aiter.Value();
      double next_log_beta = log_beta_data[arc.nextstate];
      BaseFloat transition_logprob = -arc.weight.Value();
      int32 index = *this_fst_output_indexes_iter;
      BaseFloat pseudo_loglike = nnet_logprob_data[index];
      // accumulate beta(state) over the outgoing arcs
      this_log_beta = LogAdd(this_log_beta, pseudo_loglike +
                             transition_logprob + next_log_beta);
      // occupation probability of this arc (the numerator posterior)
      BaseFloat occupation_logprob = this_log_alpha + pseudo_loglike +
          transition_logprob + next_log_beta - tot_log_prob,
          occupation_prob = exp(occupation_logprob);
      nnet_logprob_deriv_data[index] += occupation_prob;
    }
    // check for -inf.
    KALDI_PARANOID_ASSERT(this_log_beta - this_log_beta == 0);
    log_beta_data[state] = this_log_beta;
  }
  KALDI_ASSERT(fst_output_indexes_iter == &(fst_output_indexes_[0]));

  int32 start_state = 0;  // the fact that the start state is numbered 0 is
                          // implied by other properties of the FST
                          // (epsilon-free-ness and topological sorting, and
                          // connectedness).
  double tot_log_prob_backward = log_beta_(start_state);
  if (!ApproxEqual(tot_log_prob_backward, tot_log_prob_))
    KALDI_WARN << "Disagreement in forward/backward log-probs: "
               << tot_log_prob_backward << " vs. " << tot_log_prob_;

  // copy this data to GPU.
  CuVector<BaseFloat> nnet_logprob_deriv_cuda;
  nnet_logprob_deriv_cuda.Swap(&nnet_logprob_derivs_);
  // nnet_output_indexes_ is a list of (row, column) indexes into nnet_output_
  // that the forward-backward needs to look up; the order is arbitrary, but
  // the entries of fst_output_indexes_ point into this vector, and it is
  // important that each (row, column) pair appear only once, so that the
  // derivatives are summed correctly.  The call below does, in effect,
  //   nnet_output_deriv(nnet_output_indexes_[i].first,
  //                     nnet_output_indexes_[i].second) +=
  //       supervision_.weight * nnet_logprob_deriv_cuda.Data()[i];
  nnet_output_deriv->AddElements(supervision_.weight, nnet_output_indexes_,
                                 nnet_logprob_deriv_cuda.Data());
}
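In equation form, for an arc a from state s to state s′, with ℓ(a) the looked-up pseudo log-likelihood and c(a) the transition log-probability:

    \beta(s) = \text{final}(s) \oplus \bigoplus_{a\in\text{arcs}(s)} \big(\ell(a) + c(a) + \beta(s')\big),
    \qquad
    \gamma(a) = \exp\big(\alpha(s) + \ell(a) + c(a) + \beta(s') - \log p_{num}\big),

and γ(a) is accumulated into the derivative slot of the (frame, pdf-id) pair that arc a indexes; these are the numerator occupancies that ComputeChainObjfAndDeriv() adds into nnet_output_deriv and, when cross-entropy regularization is on, into xent_output_deriv.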

     

• Original post: https://www.cnblogs.com/JarvanWang/p/10281089.html