zoukankan      html  css  js  c++  java
  • MXNet中LSTM例子注记

    Preface

    序列问题也是一个interesting的issue.找了一会LSTM的材料,发现并没有一个系统的文字,早期Sepp Hochreiterpaper和弟子Felix Gersthesis看起来并没有那么轻松。最开始入手的是15年的一个review,当时看起来也不太顺畅,但看了前两个(一部分)再回头来看这篇的formulation部分,会清晰些。
    本来打算自己写个程序理一下,发现这里有个参考,程序很短,Python写的总共没有200line,但要从里面理出结构来有些费劲。想起MXNet里面好像有些例子(example/bi-lstm-sort),找出来查看。里面用symbol构建了LSTM基本单元,然后用bucket特性进行优化。感觉还不错,顺带可以看看bucket怎么用的。

    Code Plus Comment

    这段程序里面用symbol构建了记忆单元,然后用之构建了一个完整的symbol,之前以为是用了内建的的一个符号,但发现MXNet-V1.0版本上LSTM单元内建符号都还处于dev阶段,所以比较感兴趣的是怎么做到时序关联的。

    ######## from example/bi-lstm-sort/lstm.py #############
    def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):  # 构建一个单元
        """LSTM Cell symbol""" 
        if dropout > 0.:
            indata = mx.sym.Dropout(data=indata, p=dropout)
        i2h = mx.sym.FullyConnected(data=indata,
                                    weight=param.i2h_weight,
                                    bias=param.i2h_bias,
                                    num_hidden=num_hidden * 4,
                                    name="t%d_l%d_i2h" % (seqidx, layeridx))
        h2h = mx.sym.FullyConnected(data=prev_state.h,
                                    weight=param.h2h_weight,
                                    bias=param.h2h_bias,
                                    num_hidden=num_hidden * 4,
                                    name="t%d_l%d_h2h" % (seqidx, layeridx))
        gates = i2h + h2h
        slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                          name="t%d_l%d_slice" % (seqidx, layeridx))
        in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
        in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
        forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
        out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
        next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
        next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
        return LSTMState(c=next_c, h=next_h)
    
    def bi_lstm_unroll(seq_len, input_size,
                    num_hidden, num_embed, num_label, dropout=0.):
    
        embed_weight = mx.sym.Variable("embed_weight")
        cls_weight = mx.sym.Variable("cls_weight")
        cls_bias = mx.sym.Variable("cls_bias")
        last_states = []
        last_states.append(LSTMState(c = mx.sym.Variable("l0_init_c"), h = mx.sym.Variable("l0_init_h")))
        last_states.append(LSTMState(c = mx.sym.Variable("l1_init_c"), h = mx.sym.Variable("l1_init_h")))
        forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                                  i2h_bias=mx.sym.Variable("c"),
                                  h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                                  h2h_bias=mx.sym.Variable("l0_h2h_bias"))
        backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
                                  i2h_bias=mx.sym.Variable("l1_i2h_bias"),
                                  h2h_weight=mx.sym.Variable("l1_h2h_weight"),
                                  h2h_bias=mx.sym.Variable("l1_h2h_bias"))
    
        # embeding layer
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=input_size,
                                 weight=embed_weight, output_dim=num_embed, name='embed')
        wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
    
        forward_hidden = []
        for seqidx in range(seq_len):
            hidden = wordvec[seqidx]
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[0],
                              param=forward_param,
                              seqidx=seqidx, layeridx=0, dropout=dropout)
            hidden = next_state.h
            last_states[0] = next_state
            forward_hidden.append(hidden)
    
        backward_hidden = []  # 从文件夹名字看得出来,这是个双向的符号,所以会有 backward部分(just a guess :) )
        for seqidx in range(seq_len):
            k = seq_len - seqidx - 1
            hidden = wordvec[k]
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[1],
                              param=backward_param,
                              seqidx=k, layeridx=1,dropout=dropout)
            hidden = next_state.h
            last_states[1] = next_state
            backward_hidden.insert(0, hidden)
    
        hidden_all = []
        for i in range(seq_len):
            hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
    
        hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
        pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                     weight=cls_weight, bias=cls_bias, name='pred')
    
        label = mx.sym.transpose(data=label)
        label = mx.sym.Reshape(data=label, target_shape=(0,))
        sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
    
        return sm
    

    Embedding Op

    之前翻API的时候,就看到过这个符号,当时虽然明白了可以实现什么功能(虽然也明白错了:以为只能实现引索/编码),但想不到这种功能拿来可以做什么-_-||。现在遇上了,顺带就查查看。API没有附上相应的文献说明这个实现参照的什么,只好到处搜。知乎上的回答应该可以帮助到。整理一下,就是说,用非one-hot编码的方式,对词表进行编码(消除各维数之间的独立性,各维数取值是连续的)。

    note

    1. 这带来的另一个问题是如何进行优化?这个后面把paper看了再看情况吧(只是猜测这一层应该要实现update)。
    2. 另一个问题是,如果不使用one-hot编码,在模型输出阶段,如何进行解码?从程序上来看,在输出阶段,softmax输出的会被认为是one-hot编码,从而避免这个问题。

    Time Dependency

    这个例子里面只使用了一层的记忆单元(但根据paper上的情况来看,即使只有一个单元,体型也是很大的)。完整符号的构建是在bi_lstm_unroll里面进行的,其时序依赖关系的建立方案如下。先将输入的单个完整序列(向量序列)用SliceChannel分离成为单个的向量,然后按照分离出的向量个数构建一个完整的符号,由于此时已经知道向量的个数,可以不断地堆积记忆单元,直到每个向量都分配到了对应的处理单元。每个单元使用的参数被指定为同一组(l0_i2h_weight, l0_i2h_bias, etc.)。这样就实现了效果上的循环计算。

    note

    1. 此处产生的另一个问题是,如何处理变长度的输入序列问题。这应该与bucket机制有关,后面找时间看看去。但可以猜测下bucket要解决的问题,从大神的blog看,bucket机制要对每个设定好的长度绑定生成一个模型,并且由于这些长度都是离散的,可能还要进行补齐的操作,如果进行了补齐那么还要处理由此产生的训练更新问题。
    2. 另一个问题是,多个节点使用同一组参数,进行backwardupdate过程时,参数是如何更新的。这里先放些推测。之前有看到过grad的操作可以有null, write, add似乎默认是write(overwrite);grad的应该是按照节点为单位分配。所以参数的更新会是,那一组被引用的参数在不同节点处,按照当地的backward计算结果进行更新操作。这个结论下的更新操作看上去是合理的。

    LSTM Implementation

    来看看怎么构建一个记忆单元的吧,过一段时间内建版本发布了,说不定这个例子也像rcnn一样看不到手撕的细节了(好吧,至少不那么容易)。
    lstmfunction里面实现的是A Critical Review of Recurrent Neural Networks for Sequence Learning Page-20上的式子,并不是Felix Gers thesis Page-17Figure 3.1描述的形式,关于这一点,前者在那页上有段注记:These equations give the full algorithm for a modern LSTM ...。我还是把式子打一遍吧...

    [egin{eqnarray} g^{(t)} &=& phi (W^{gx}x^{(t)} + W^{gh}h^{(t-1)} + b_g) onumber\ i^{(t)} &=& phi (W^{ix}x^{(t)} + W^{ih}h^{(t-1)} + b_i) onumber\ f^{(t)} &=& phi (W^{fx}x^{(t)} + W^{fh}h^{(t-1)} + b_f) onumber\ o^{(t)} &=& phi (W^{ox}x^{(t)} + W^{oh}h^{(t-1)} + b_o) onumber\ s^{(t)} &=& g^{(t)}odot i^{(t)} + s^{(t-1)}odot f^{(t)} onumber\ h^{(t)} &=& phi (s^{(t)}) odot o^{(t)} onumber end{eqnarray} ]

    可以观察到,每个非线性映射的输入变量都是相同的((x^{(t)},~h^{(t-1)})),对应到lstm function里面,i2hh2h被直接加起来,然后再分为相应的gate参数。

    Graph

    说了这么些,再来看看最后生成的图是怎样的(图有些大,右键单独查看为好):

    Figure 1. Graph of the *LSTM* for 5-length input
    可以观察到,底层部分除了`data`节点以外,还存在有青色节点,按照这个命名方式是不能被初始化的,在*sort_io.py*里面为这些节点提供了参数。从这个图里面也可以窥测到,lstm的计算密度是很大的。 # Note 最后在附上一个注记吧,程序虽然是以`sort`命名的,但从内容上看,这样的训练是将每个数字作为一个单词输入进去的,也就是说,测试的时候输入的数字序列也必须是训练时出现过的(没严格验证过,猜测啦)
  • 相关阅读:
    《DSP using MATLAB》 示例 Example 9.12
    《DSP using MATLAB》示例 Example 9.11
    《DSP using MATLAB》示例 Example 9.10
    《DSP using MATLAB》示例Example 9.9
    《DSP using MATLAB》示例 Example 9.8
    《DSP using MATLAB》示例Example 9.7
    《DSP using MATLAB》示例 Example 9.6
    《DSP using MATLAB》示例Example 9.5
    《DSP using MATLAB》示例 Example 9.4
    (转载)【C++11新特性】 nullptr关键字
  • 原文地址:https://www.cnblogs.com/chenyliang/p/8053424.html
Copyright © 2011-2022 走看看