  • Computing the chain training criterion

    Validation-set log at iteration 1000:

    log/compute_prob_valid.1000.log

    LOG (nnet3-chain-compute-prob[5.5.100-d66be]:PrintTotalStats():nnet-chain-diagnostics.cc:194) Overall log-probability for 'output-xent' is -2.14993 per frame, over 18230 frames.

    LOG (nnet3-chain-compute-prob[5.5.100-d66be]:PrintTotalStats():nnet-chain-diagnostics.cc:194) Overall log-probability for 'output' is -0.238675 per frame, over 18230 frames.

    Here "Overall log-probability" is the per-frame average of the loss (objective) function, evaluated on the validation set at this iteration.
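    To track this value across iterations it can help to pull the numbers out of the log line. The following is a minimal sketch, assuming the line has exactly the shape shown above; the struct OverallLogProb and the function ParseOverallLogProb are hypothetical names introduced here, not part of Kaldi.

```cpp
#include <cassert>
#include <regex>
#include <string>

// Fields recovered from one "Overall log-probability" line (hypothetical type).
struct OverallLogProb {
  std::string output_name;  // e.g. "output" or "output-xent"
  double per_frame;         // average objective per frame
  double frames;            // total (weighted) frame count
};

// Returns true and fills *out if the line matches the shape shown in the log;
// returns false (leaving *out untouched) for any other line.
inline bool ParseOverallLogProb(const std::string &line, OverallLogProb *out) {
  assert(out != nullptr);
  static const std::regex pattern(
      R"(Overall log-probability for '([^']+)' is (-?[0-9.eE+-]+) per frame, over ([0-9.eE+]+) frames)");
  std::smatch m;
  if (!std::regex_search(line, m, pattern)) return false;
  *out = OverallLogProb{m[1].str(), std::stod(m[2].str()), std::stod(m[3].str())};
  return true;
}
```

    Applied to the two lines above, this would yield the output names 'output-xent' and 'output' together with their per-frame values and the frame count 18230.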

       

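    For chain models, the loss function is LF-MMI (lattice-free maximum mutual information): for each sequence, the log-probability of the numerator (supervision) graph minus the log-probability of the denominator graph.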
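    In a common notation (the symbols below are assumptions of this sketch, not taken from the original post), the LF-MMI objective over U training sequences can be written as a sum of numerator-minus-denominator log-probabilities:

```latex
\mathcal{F}_{\mathrm{LF\text{-}MMI}}
  = \sum_{u=1}^{U} \left[
      \log p\!\left(\mathbf{O}_u \mid \mathcal{G}^{\mathrm{num}}_u\right)
      - \log p\!\left(\mathbf{O}_u \mid \mathcal{G}^{\mathrm{den}}\right)
    \right]
```

    Here O_u is the observation sequence of sequence u, G^num_u is the numerator graph built from that sequence's supervision, and G^den is the denominator graph shared by all sequences. The bracketed difference corresponds to num_logprob_weighted - den_logprob_weighted in ComputeChainObjfAndDeriv.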

       

    nnet3/nnet-chain-diagnostics.cc

    void NnetChainComputeProb::ProcessOutputs(const NnetChainExample &eg,
                                              NnetComputer *computer) {
      std::vector<NnetChainSupervision>::const_iterator iter = eg.outputs.begin(),
          end = eg.outputs.end();
      for (; iter != end; ++iter) {
        BaseFloat tot_like, tot_l2_term, tot_weight;
        //...
        ComputeChainObjfAndDeriv(chain_config_, den_graph_,
                                 sup.supervision, nnet_output,
                                 &tot_like, &tot_l2_term, &tot_weight,
                                 (nnet_config_.compute_deriv ? &nnet_output_deriv :
                                  NULL), (use_xent ? &xent_deriv : NULL));
        //...
        ChainObjectiveInfo &totals = objf_info_[sup.name];
        totals.tot_weight += tot_weight;
        totals.tot_like += tot_like;
        totals.tot_l2_term += tot_l2_term;
        //...
      }
    }
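    The accumulation pattern in ProcessOutputs (one running total per output name) can be sketched in isolation as follows. This is a simplified illustration, not the Kaldi implementation: ObjectiveTotals and AccumulateTotals are hypothetical stand-ins for ChainObjectiveInfo and the loop body, and double replaces BaseFloat.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for Kaldi's ChainObjectiveInfo: running totals only.
struct ObjectiveTotals {
  double tot_weight = 0.0;   // weighted frame count
  double tot_like = 0.0;     // summed (not averaged) objective
  double tot_l2_term = 0.0;  // summed L2 regularization term
};

// Adds one example's contribution to the totals stored under output_name.
// operator[] creates a zero-initialized entry the first time a name is seen.
inline void AccumulateTotals(
    std::map<std::string, ObjectiveTotals> *totals_by_output,
    const std::string &output_name,
    double like, double l2_term, double weight) {
  assert(totals_by_output != nullptr);
  ObjectiveTotals &totals = (*totals_by_output)[output_name];
  totals.tot_weight += weight;
  totals.tot_like += like;
  totals.tot_l2_term += l2_term;
}
```

    Keying the map by output name is what lets the log report 'output' and 'output-xent' as separate lines.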

       

    void ComputeChainObjfAndDeriv(...){
      *objf = num_logprob_weighted - den_logprob_weighted;
      // supervision.weight: weight of the example (eg), usually 1.0
      // supervision.num_sequences: number of sequences in this Supervision object
      //   (each built from a lattice or an alignment), i.e. the number of FSTs
      //   or utterance chunks
      // supervision.frames_per_sequence: number of frames in each sequence
      // weight is therefore the weighted frame count of this example
      *weight = supervision.weight * supervision.num_sequences *
          supervision.frames_per_sequence;
    }
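    The two assignments above can be reproduced with plain numbers. The sketch below is a minimal illustration under stated assumptions: SupervisionSummary, ExampleObjf and ExampleWeight are hypothetical names that mirror only the three fields and two expressions shown in the excerpt.

```cpp
#include <cassert>

// Hypothetical subset of the fields the excerpt reads from the supervision.
struct SupervisionSummary {
  double weight;            // example weight, usually 1.0
  int num_sequences;        // sequences merged into this example
  int frames_per_sequence;  // frames in each sequence
};

// Objective for one example: numerator minus denominator log-probability.
inline double ExampleObjf(double num_logprob_weighted,
                          double den_logprob_weighted) {
  return num_logprob_weighted - den_logprob_weighted;
}

// Weighted frame count for one example, used later as the divisor
// when the summed objective is turned into a per-frame average.
inline double ExampleWeight(const SupervisionSummary &s) {
  assert(s.num_sequences >= 0 && s.frames_per_sequence >= 0);
  return s.weight * s.num_sequences * s.frames_per_sequence;
}
```

    For instance, with an illustrative (made-up) example of weight 1.0 holding 128 sequences of 50 frames each, ExampleWeight gives 6400 frames.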

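    // like: average log-likelihood per frame over the archive
    BaseFloat like = (info.tot_like / info.tot_weight),
        // average L2 regularization term per frame over the archive
        l2_term = (info.tot_l2_term / info.tot_weight),
        // average value of the criterion per frame over the archive
        tot_objf = like + l2_term;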
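    The averaging step can be checked against the log. Since the log reports -0.238675 per frame over 18230 frames for 'output', the implied summed log-likelihood is about -0.238675 × 18230 ≈ -4351.05. The sketch below repeats the three divisions with hypothetical names (PerFrameStats, ComputePerFrameStats) and double in place of BaseFloat; it is an illustration, not Kaldi code.

```cpp
#include <cassert>
#include <cmath>

// Per-frame averages derived from the accumulated totals (hypothetical type).
struct PerFrameStats {
  double like;      // average log-likelihood per frame
  double l2_term;   // average L2 regularization term per frame
  double tot_objf;  // like + l2_term
};

// Divides the summed quantities by the weighted frame count.
inline PerFrameStats ComputePerFrameStats(double tot_like,
                                          double tot_l2_term,
                                          double tot_weight) {
  assert(tot_weight > 0.0);  // averaging over zero frames is undefined
  PerFrameStats s;
  s.like = tot_like / tot_weight;
  s.l2_term = tot_l2_term / tot_weight;
  s.tot_objf = s.like + s.l2_term;
  return s;
}
```

    Calling ComputePerFrameStats(-4351.04525, 0.0, 18230.0) recovers a per-frame value of approximately -0.238675, matching the 'output' line.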

       

    Because MMI aims to maximize mutual information, the criterion must be maximized (equivalently, its negative must be minimized).

       

    Therefore, in the log below, a larger (less negative) "Overall log-probability" is better.

    log/compute_prob_valid.1000.log

    LOG (nnet3-chain-compute-prob[5.5.100-d66be]:PrintTotalStats():nnet-chain-diagnostics.cc:194) Overall log-probability for 'output-xent' is -2.14993 per frame, over 18230 frames.

    LOG (nnet3-chain-compute-prob[5.5.100-d66be]:PrintTotalStats():nnet-chain-diagnostics.cc:194) Overall log-probability for 'output' is -0.238675 per frame, over 18230 frames.

       

  • Original post: https://www.cnblogs.com/JarvanWang/p/10145889.html