zoukankan      html  css  js  c++  java
  • MXNet中LSTM例子注记

    Preface

    序列问题也是一个interesting的issue.找了一会LSTM的材料,发现并没有一个系统的文字,早期Sepp Hochreiterpaper和弟子Felix Gersthesis看起来并没有那么轻松。最开始入手的是15年的一个review,当时看起来也不太顺畅,但看了前两个(一部分)再回头来看这篇的formulation部分,会清晰些。
    本来打算自己写个程序理一下,发现这里有个参考,程序很短,Python写的总共没有200line,但要从里面理出结构来有些费劲。想起MXNet里面好像有些例子(example/bi-lstm-sort),找出来查看。里面用symbol构建了LSTM基本单元,然后用bucket特性进行优化。感觉还不错,顺带可以看看bucket怎么用的。

    Code Plus Comment

    这段程序里面用symbol构建了记忆单元,然后用之构建了一个完整的symbol,之前以为是用了内建的的一个符号,但发现MXNet-V1.0版本上LSTM单元内建符号都还处于dev阶段,所以比较感兴趣的是怎么做到时序关联的。

    ######## from example/bi-lstm-sort/lstm.py #############
    def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):  # 构建一个单元
        """LSTM Cell symbol""" 
        if dropout > 0.:
            indata = mx.sym.Dropout(data=indata, p=dropout)
        i2h = mx.sym.FullyConnected(data=indata,
                                    weight=param.i2h_weight,
                                    bias=param.i2h_bias,
                                    num_hidden=num_hidden * 4,
                                    name="t%d_l%d_i2h" % (seqidx, layeridx))
        h2h = mx.sym.FullyConnected(data=prev_state.h,
                                    weight=param.h2h_weight,
                                    bias=param.h2h_bias,
                                    num_hidden=num_hidden * 4,
                                    name="t%d_l%d_h2h" % (seqidx, layeridx))
        gates = i2h + h2h
        slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
                                          name="t%d_l%d_slice" % (seqidx, layeridx))
        in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
        in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
        forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
        out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
        next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
        next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
        return LSTMState(c=next_c, h=next_h)
    
    def bi_lstm_unroll(seq_len, input_size,
                    num_hidden, num_embed, num_label, dropout=0.):
    
        embed_weight = mx.sym.Variable("embed_weight")
        cls_weight = mx.sym.Variable("cls_weight")
        cls_bias = mx.sym.Variable("cls_bias")
        last_states = []
        last_states.append(LSTMState(c = mx.sym.Variable("l0_init_c"), h = mx.sym.Variable("l0_init_h")))
        last_states.append(LSTMState(c = mx.sym.Variable("l1_init_c"), h = mx.sym.Variable("l1_init_h")))
        forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                                  i2h_bias=mx.sym.Variable("c"),
                                  h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                                  h2h_bias=mx.sym.Variable("l0_h2h_bias"))
        backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
                                  i2h_bias=mx.sym.Variable("l1_i2h_bias"),
                                  h2h_weight=mx.sym.Variable("l1_h2h_weight"),
                                  h2h_bias=mx.sym.Variable("l1_h2h_bias"))
    
        # embeding layer
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=input_size,
                                 weight=embed_weight, output_dim=num_embed, name='embed')
        wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)
    
        forward_hidden = []
        for seqidx in range(seq_len):
            hidden = wordvec[seqidx]
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[0],
                              param=forward_param,
                              seqidx=seqidx, layeridx=0, dropout=dropout)
            hidden = next_state.h
            last_states[0] = next_state
            forward_hidden.append(hidden)
    
        backward_hidden = []  # 从文件夹名字看得出来,这是个双向的符号,所以会有 backward部分(just a guess :) )
        for seqidx in range(seq_len):
            k = seq_len - seqidx - 1
            hidden = wordvec[k]
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[1],
                              param=backward_param,
                              seqidx=k, layeridx=1,dropout=dropout)
            hidden = next_state.h
            last_states[1] = next_state
            backward_hidden.insert(0, hidden)
    
        hidden_all = []
        for i in range(seq_len):
            hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
    
        hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
        pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                     weight=cls_weight, bias=cls_bias, name='pred')
    
        label = mx.sym.transpose(data=label)
        label = mx.sym.Reshape(data=label, target_shape=(0,))
        sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
    
        return sm
    

    Embedding Op

    之前翻API的时候,就看到过这个符号,当时虽然明白了可以实现什么功能(虽然也明白错了:以为只能实现引索/编码),但想不到这种功能拿来可以做什么-_-||。现在遇上了,顺带就查查看。API没有附上相应的文献说明这个实现参照的什么,只好到处搜。知乎上的回答应该可以帮助到。整理一下,就是说,用非one-hot编码的方式,对词表进行编码(消除各维数之间的独立性,各维数取值是连续的)。

    note

    1. 这带来的另一个问题是如何进行优化?这个后面把paper看了再看情况吧(只是猜测这一层应该要实现update)。
    2. 另一个问题是,如果不使用one-hot编码,在模型输出阶段,如何进行解码?从程序上来看,在输出阶段,softmax输出的会被认为是one-hot编码,从而避免这个问题。

    Time Dependency

    这个例子里面只使用了一层的记忆单元(但根据paper上的情况来看,即使只有一个单元,体型也是很大的)。完整符号的构建是在bi_lstm_unroll里面进行的,其时序依赖关系的建立方案如下。先将输入的单个完整序列(向量序列)用SliceChannel分离成为单个的向量,然后按照分离出的向量个数构建一个完整的符号,由于此时已经知道向量的个数,可以不断地堆积记忆单元,直到每个向量都分配到了对应的处理单元。每个单元使用的参数被指定为同一组(l0_i2h_weight, l0_i2h_bias, etc.)。这样就实现了效果上的循环计算。

    note

    1. 此处产生的另一个问题是,如何处理变长度的输入序列问题。这应该与bucket机制有关,后面找时间看看去。但可以猜测下bucket要解决的问题,从大神的blog看,bucket机制要对每个设定好的长度绑定生成一个模型,并且由于这些长度都是离散的,可能还要进行补齐的操作,如果进行了补齐那么还要处理由此产生的训练更新问题。
    2. 另一个问题是,多个节点使用同一组参数,进行backwardupdate过程时,参数是如何更新的。这里先放些推测。之前有看到过grad的操作可以有null, write, add似乎默认是write(overwrite);grad的应该是按照节点为单位分配。所以参数的更新会是,那一组被引用的参数在不同节点处,按照当地的backward计算结果进行更新操作。这个结论下的更新操作看上去是合理的。

    LSTM Implementation

    来看看怎么构建一个记忆单元的吧,过一段时间内建版本发布了,说不定这个例子也像rcnn一样看不到手撕的细节了(好吧,至少不那么容易)。
    lstmfunction里面实现的是A Critical Review of Recurrent Neural Networks for Sequence Learning Page-20上的式子,并不是Felix Gers thesis Page-17Figure 3.1描述的形式,关于这一点,前者在那页上有段注记:These equations give the full algorithm for a modern LSTM ...。我还是把式子打一遍吧...

    [egin{eqnarray} g^{(t)} &=& phi (W^{gx}x^{(t)} + W^{gh}h^{(t-1)} + b_g) onumber\ i^{(t)} &=& phi (W^{ix}x^{(t)} + W^{ih}h^{(t-1)} + b_i) onumber\ f^{(t)} &=& phi (W^{fx}x^{(t)} + W^{fh}h^{(t-1)} + b_f) onumber\ o^{(t)} &=& phi (W^{ox}x^{(t)} + W^{oh}h^{(t-1)} + b_o) onumber\ s^{(t)} &=& g^{(t)}odot i^{(t)} + s^{(t-1)}odot f^{(t)} onumber\ h^{(t)} &=& phi (s^{(t)}) odot o^{(t)} onumber end{eqnarray} ]

    可以观察到,每个非线性映射的输入变量都是相同的((x^{(t)},~h^{(t-1)})),对应到lstm function里面,i2hh2h被直接加起来,然后再分为相应的gate参数。

    Graph

    说了这么些,再来看看最后生成的图是怎样的(图有些大,右键单独查看为好):

    Figure 1. Graph of the *LSTM* for 5-length input
    可以观察到,底层部分除了`data`节点以外,还存在有青色节点,按照这个命名方式是不能被初始化的,在*sort_io.py*里面为这些节点提供了参数。从这个图里面也可以窥测到,lstm的计算密度是很大的。 # Note 最后在附上一个注记吧,程序虽然是以`sort`命名的,但从内容上看,这样的训练是将每个数字作为一个单词输入进去的,也就是说,测试的时候输入的数字序列也必须是训练时出现过的(没严格验证过,猜测啦)
  • 相关阅读:
    Java过滤器与SpringMVC拦截器之间的关系与区别
    Linux分区,并且把新的分区挂载到指定的文件夹
    HibernateTemplate和HibernateDaoSupport(spring注入问题)
    EJB到底是什么,真的那么神秘吗??
    Hibernate关联关系配置(一对多、一对一和多对多)
    Druid的使用步骤
    Spring MVC静态资源处理(在applicationContex.xml文件中进行配置)
    Spring <context:annotation-config/> 解说
    hibernate的五大接口
    oracle11g数据库的安装以及安装之后的配置
  • 原文地址:https://www.cnblogs.com/chenyliang/p/8053424.html
Copyright © 2011-2022 走看看