  • Kaldi attention analysis

    xconfig example

    num_targets=3766

    learning_rate_factor=20

    dir=`mktemp -d`

    mkdir -p $dir/configs

    cat <<EOF > $dir/configs/network.xconfig

    input dim=71 name=input

    attention-relu-renorm-layer name=attention1 num-heads=5 value-dim=40 key-dim=20 num-left-inputs=5 num-right-inputs=2 time-stride=3

    output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5

    EOF

    (cd ~/kaldi/egs/wsj/s5;steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/)

    config example

    component name=attention1.attention type=RestrictedAttentionComponent value-dim=40 key-dim=20 num-left-inputs=5 num-right-inputs=2 num-left-inputs-required=-1 num-right-inputs-required=-1 output-context=True time-stride=3 num-heads=5 key-scale=0.158113883008

    component-node name=attention1.attention component=attention1.attention input=attention1.affine
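    The dimensions implied by this configuration can be checked by hand. The sketch below is a minimal illustration, assuming that each head consumes a key, a value and a query, and that the query holds a key-sized part plus one bias per attended position; the function name attention_dims is hypothetical and not part of Kaldi. With the values above it gives 88 inputs and 48 outputs per head, i.e. 440 and 240 in total.

```python
def attention_dims(num_heads, key_dim, value_dim,
                   num_left_inputs, num_right_inputs, output_context=True):
    """Return (input_dim, output_dim) of the attention component (sketch)."""
    # One attention weight per attended position: left, right and the centre frame.
    context_dim = num_left_inputs + num_right_inputs + 1
    # Assumption: the query is a key-sized vector plus one bias per position.
    query_dim = key_dim + context_dim
    input_dim_per_head = key_dim + value_dim + query_dim
    # With output-context=True the attention weights are appended to the output.
    output_dim_per_head = value_dim + (context_dim if output_context else 0)
    return num_heads * input_dim_per_head, num_heads * output_dim_per_head


dims = attention_dims(num_heads=5, key_dim=20, value_dim=40,
                      num_left_inputs=5, num_right_inputs=2)
print(dims)
```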

    raw.txt example

    <ComponentName> attention1.attention <RestrictedAttentionComponent> <NumHeads> 5 <KeyDim> 20 <ValueDim> 40 <NumLeftInputs> 5 <NumRightInputs> 2 <TimeStride> 3 <NumLeftInputsRequired> 5 <NumRightInputsRequired> 2 <OutputContext> T <KeyScale> 0.1581139 <StatsCount> 0 <EntropyStats> [ ]

    <PosteriorStats> [ ]

    </RestrictedAttentionComponent>

    Topology

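    The topology shows that Kaldi nnet3's RestrictedAttentionComponent plays the role of a nonlinearity layer: it takes the output of the affine component (attention1.affine) as its input, in the same way a nonlinearity would.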

    gdb example

    $ gdb -d ~/kaldi/src/nnet3 --args nnet3-compute ref.raw ark,t:/tmp/feat ark:/dev/null

    (gdb) rb kaldi::nnet3::.*::Propagate

    (gdb) run

    Breakpoint 3, kaldi::nnet3::AffineComponent::Propagate (this=0x11a6ec80, indexes=0x0, in=...,

    out=0x7fffffffb790) at nnet-simple-component.cc:1236

    The input is a 71x71 matrix

    (gdb) printf "%d, %d, %d, %d ", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

    71, 71, 71, 440
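    The two 71s in this output mean different things: rows are frames (71 of them in this run), columns are the feature dimension (dim=71 in the xconfig). A minimal NumPy sketch of the shape change made by the affine layer, with random stand-in parameters (weights and bias here are hypothetical, not values from the model):

```python
import numpy as np

rng = np.random.default_rng(0)
num_frames, in_dim, out_dim = 71, 71, 440

x = rng.standard_normal((num_frames, in_dim))     # one row per frame
weights = rng.standard_normal((out_dim, in_dim))  # stand-in parameters
bias = rng.standard_normal(out_dim)

# Affine transform applied to every frame independently.
y = x @ weights.T + bias
print(x.shape, y.shape)
```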

    (gdb) c

    Breakpoint 43, kaldi::nnet3::RestrictedAttentionComponent::Propagate (this=0x12f07e40, indexes_in=0x12f09a00,

    in=..., out=0x7fffffffb790) at nnet-attention-component.cc:134

    The input is a 71x440 matrix, split into 5 heads:

    Head 1   Head 2   Head 3   Head 4   Head 5
    71x88    71x88    71x88    71x88    71x88

    (gdb) printf "%d, %d, %d, %d ", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

    71, 440, 50, 240

    Here attention is applied to each head separately, i.e. PropagateOneHead

    (gdb) c

    Breakpoint 44, kaldi::nnet3::RestrictedAttentionComponent::PropagateOneHead (this=this@entry=0x12f07e40,

    io=..., in=..., c=c@entry=0x7fffffffb630, out=out@entry=0x7fffffffb650) at nnet-attention-component.cc:164

    164 CuMatrixBase<BaseFloat> *out) const {

    (gdb) printf "%d, %d, %d, %d ", in.NumRows(), in.NumCols(), out->NumRows(), out->NumCols()

    71, 88, 50, 48

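    The 71 frames consist of:

    1. num-left-inputs*time-stride=5*3=15 frames of left context, which produce no output
    2. the middle 50 frames, which produce output
    3. num-right-inputs*time-stride=2*3=6 frames of right context, which produce no output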
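    The frame arithmetic above can be sketched as follows. This is an illustration only, assuming a single sequence whose rows are consecutive frames; frame_layout and attended_rows are hypothetical helper names.

```python
def frame_layout(num_in, num_left_inputs, num_right_inputs, time_stride):
    """Split num_in input frames into (left context, output, right context)."""
    left = num_left_inputs * time_stride
    right = num_right_inputs * time_stride
    return left, num_in - left - right, right


def attended_rows(out_row, num_left_inputs, num_right_inputs, time_stride):
    """Input rows that output row `out_row` (0-based) attends to."""
    centre = out_row + num_left_inputs * time_stride
    return [centre + k * time_stride
            for k in range(-num_left_inputs, num_right_inputs + 1)]


layout = frame_layout(71, 5, 2, 3)
first = attended_rows(0, 5, 2, 3)
last = attended_rows(49, 5, 2, 3)
print(layout, first, last)
```

    Each output frame attends to 8 input rows spaced 3 apart, so the first output row reaches back to input row 0 and the last one reaches forward to input row 70.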

    A worked example of the PropagateOneHead computation:
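    As a rough NumPy sketch of what one head computes, under stated assumptions: the 88 input columns are taken to be laid out as [key (20) | value (40) | query (28)], the last 8 query columns act as a per-position bias added to the scaled dot products, and the softmax weights are appended to the output because output-context is true. The function below is an illustration written for this note, not Kaldi's code.

```python
import numpy as np


def propagate_one_head(x, key_dim, value_dim, num_left_inputs,
                       num_right_inputs, time_stride, key_scale):
    context_dim = num_left_inputs + num_right_inputs + 1
    num_out = x.shape[0] - (context_dim - 1) * time_stride
    first_out = num_left_inputs * time_stride

    # Assumed column layout of one head's input: [keys | values | queries].
    keys = x[:, :key_dim]
    values = x[:, key_dim:key_dim + value_dim]
    # Only the frames that produce output contribute a query.
    queries = x[first_out:first_out + num_out, key_dim + value_dim:]
    query_key_part = queries[:, :key_dim]
    query_bias_part = queries[:, key_dim:]

    # rows[i, o]: input row that output i looks at for context position o.
    rows = (np.arange(num_out)[:, None]
            + time_stride * np.arange(context_dim)[None, :])

    # Scaled dot product between each query and its 8 candidate keys.
    scores = key_scale * np.einsum("id,iod->io", query_key_part, keys[rows])
    scores = scores + query_bias_part
    # Numerically stable softmax over the context positions.
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)

    # Weighted sum of the attended values, followed by the weights themselves.
    attended = np.einsum("io,iod->id", weights, values[rows])
    return np.hstack([attended, weights])


rng = np.random.default_rng(0)
head_in = rng.standard_normal((71, 88))
head_out = propagate_one_head(head_in, key_dim=20, value_dim=40,
                              num_left_inputs=5, num_right_inputs=2,
                              time_stride=3, key_scale=0.1581139)
print(head_out.shape)
```

    The resulting shape matches the 71x88 input and 50x48 output seen in gdb: 40 value columns plus 8 attention weights per output frame.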

       

    The computation flow of the whole RestrictedAttentionComponent:
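    In outline, the component splits its input columns into heads, processes each head on its own, and joins the per-head outputs side by side. The sketch below shows only that split/apply/join pattern, assuming heads occupy contiguous column blocks; restricted_attention and placeholder_head are hypothetical names, and placeholder_head is a stand-in that merely reproduces the shapes, not the attention computation itself.

```python
import numpy as np


def restricted_attention(x, num_heads, head_fn):
    """Apply `head_fn` to each column block of `x` and join the results."""
    head_inputs = np.hsplit(x, num_heads)        # 5 blocks of 71x88
    head_outputs = [head_fn(h) for h in head_inputs]
    return np.hstack(head_outputs)               # 5 blocks of 50x48


def placeholder_head(h, first_out=15, num_out=50, out_dim=48):
    # Stand-in for the per-head computation: keeps the 50 central rows and
    # the first 48 columns so that the shapes match the gdb session.
    return h[first_out:first_out + num_out, :out_dim]


x = np.zeros((71, 440))
y = restricted_attention(x, num_heads=5, head_fn=placeholder_head)
print(y.shape)
```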

       

  • Original article: https://www.cnblogs.com/JarvanWang/p/11084359.html