zoukankan      html  css  js  c++  java
  • Attribute 'num_units' in Tensorflow BasicLSTMCell blocks

    在之前使用Tensorflow来做音乐识别时,LSTM给出了非常让人惊喜的学习能力。当时在进行Tuning的时候,有一个参数叫做num_units,字面看来是LTSM单元的个数,但最近当我试图阅读Tensorflow源代码时,和我们最初的认知大不相同,以此博文来记录。

    先看当初我们是如何设置的:

    rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=300)
    

    看起来像是,为Hidden Layer设置了300个单独的LSTM单元,然后并行工作最终输出300个值。但实际上,我们来看一下Tensorflow的源码:(github地址),从line 326,开始定义BasicLSTMCell类,在line 374行开始定义BasicLSTMCell的核心方法call方法:

     def call(self, inputs, state):
        """Long short-term memory cell (LSTM)."""
        sigmoid = math_ops.sigmoid
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
          c, h = state
        else:
          c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
    
        concat = _linear([inputs, h], 4 * self._num_units, True)
    
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    
        new_c = (
            c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)
    
        if self._state_is_tuple:
          new_state = LSTMStateTuple(new_c, new_h)
        else:
          new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
    

    注意13行,改行的作用是,根据当前时刻的输入inputs,以及前一时刻的输出值h,去计算4个gates在经过activation function之前的线性组合值。而后15-17两行,我们使用四个gates去计算了新的LSTM Cell状态c,以及新的输出值h。

    是的,无论num_units设置为多少,这是一个LSTM Cell!如果我们查看_linear这个函数,可以看到第二个参数是output_size,也就是说num_units和LSTM Cell的输出大小有关。事实上,Tensorflow的LSTMCell表征了整个一层Hidden Layer。而num_units则表示State Cell的存储能力,或者说维度Dimension。试想在一个LSTM Neural Network中,输入tensor X的维度是确定的,输出值Y的维度也是确定的,而LSTM各个时刻间的中间状态c,以及抽象输出h,则可以为任意维度。因为h可以经过dense层(fully-connected layer)去压缩成Y所需的维度。

    所以c和h的维度越高,其蕴含的time series data细节越多,当然越容易去拟合training set。但是,容易Overfitting呀,所以tuning时平衡training set的拟合程度,以及cv set的预测精度,来达到trade off咯。

  • 相关阅读:
    python字典
    python中List添加、删除元素的几种方法
    python数据处理之基本函数
    python批量处理
    python正则表达式
    python模块学习:os模块
    Hough transform(霍夫变换)
    MODBUS TCP/IP协议规范详细介绍
    Linux下run文件的直接运行
    双边滤波和引导滤波的原理
  • 原文地址:https://www.cnblogs.com/rhyswang/p/9163628.html
Copyright © 2011-2022 走看看