  • TensorFlow source code analysis: BasicLSTMCell

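    BasicLSTMCell is the simplest LSTM cell. Its source is in /tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py.
    BasicLSTMCell inherits from RNNCell, whose source is in /tensorflow/python/ops/rnn_cell_impl.py.
    Notes:
    1. The input_size parameter is deprecated and unused; the cell size is set with num_units.
    2. The official recommendation is to set state_is_tuple to True. The input and output states are then 2-tuples of c (the cell state) and h (the output).
    3. The output h and the cell state c have the same shape, batch_size * num_units. In this post the input is also batch_size * num_units, although the cell itself does not require the input width to equal num_units.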
    cell = tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=0.0, state_is_tuple=True)  # sets num_units
    _initial_state = cell.zero_state(batch_size, tf.float32)                   # sets batch_size; initializes c and h to all zeros, each with shape batch_size * num_units
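To make note 2 concrete, here is a small NumPy sketch (not TensorFlow code) of the two state layouts. The names c, h, batch_size and num_units mirror the post; the sizes 3 and 4 are arbitrary values chosen for illustration.

```python
import numpy as np

batch_size, num_units = 3, 4  # illustrative sizes, not from the post

# state_is_tuple=True: the state is a pair (c, h), each batch_size x num_units.
c = np.zeros((batch_size, num_units), dtype=np.float32)
h = np.zeros((batch_size, num_units), dtype=np.float32)
tuple_state = (c, h)

# state_is_tuple=False: c and h are concatenated along the column axis,
# giving a single batch_size x (2 * num_units) array.
concat_state = np.concatenate([c, h], axis=1)

# Splitting the concatenated state in two along axis 1 recovers c and h,
# which is what the cell does at the start of __call__ in that mode.
c_back, h_back = np.split(concat_state, 2, axis=1)
```

The tuple form avoids the extra concatenate and split on every step, which is consistent with the warning in the source that the concatenated state is slower.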

    4. Annotated source of BasicLSTMCell:
    class BasicLSTMCell(RNNCell):
      """Basic LSTM recurrent network cell.
    
      The implementation is based on: http://arxiv.org/abs/1409.2329.
    
      We add forget_bias (default: 1) to the biases of the forget gate in order to
      reduce the scale of forgetting in the beginning of the training.
    
      It does not allow cell clipping, a projection layer, and does not
      use peep-hole connections: it is the basic baseline.
    
      For advanced models, please use the full LSTMCell that follows.
      """
    
      def __init__(self, num_units, forget_bias=1.0, input_size=None,
                   state_is_tuple=True, activation=tanh):
        """Initialize the basic LSTM cell.
    
        Args:
          num_units: int, The number of units in the LSTM cell.
          forget_bias: float, The bias added to forget gates (see above).
          input_size: Deprecated and unused.
          state_is_tuple: If True, accepted and returned states are 2-tuples of
            the `c_state` and `m_state`.  If False, they are concatenated
            along the column axis.  The latter behavior will soon be deprecated.
          activation: Activation function of the inner states.
        """
        if not state_is_tuple:
          logging.warn("%s: Using a concatenated state is slower and will soon be "
                       "deprecated.  Use state_is_tuple=True.", self)
        if input_size is not None:
          logging.warn("%s: The input_size parameter is deprecated.", self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation
    
      @property
      def state_size(self):
        return (LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)
    
      @property
      def output_size(self):
        return self._num_units
    
      def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
        with vs.variable_scope(scope or "basic_lstm_cell"):
          # Parameters of gates are concatenated into one multiply for efficiency.
          if self._state_is_tuple:
            c, h = state
          else:
            c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

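          # Linear step: concat = [inputs, h] W + b.
          # _linear allocates W and b. W has shape (input_dim + num_units, 4 * num_units),
          # i.e. (2 * num_units, 4 * num_units) when the input width equals num_units;
          # b has shape (4 * num_units,). Together they hold four sets of gate parameters.
          # concat has shape (batch_size, 4 * num_units).
          # Note: inputs and h are concatenated and multiplied by a single W, so the input
          # width does not have to equal num_units. Written with separate input and
          # recurrent weights, this would be two W matrices, each holding four gate blocks.
          concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)

          # i = input_gate, j = new_input, f = forget_gate, o = output_gate
          i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

          new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                   self._activation(j))
          new_h = self._activation(new_c) * sigmoid(o)

          if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
          else:
            new_state = array_ops.concat([new_c, new_h], 1)
          return new_h, new_state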
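The arithmetic in __call__ can be checked outside TensorFlow. The following is a minimal NumPy sketch of one step, assuming tanh as the activation; the function name basic_lstm_step, the variable input_dim, and the random weights are illustrative and not part of the TensorFlow source.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def basic_lstm_step(inputs, c, h, W, b, forget_bias=1.0):
    """One BasicLSTMCell-style step in NumPy (hypothetical helper).

    inputs: (batch_size, input_dim); c, h: (batch_size, num_units)
    W: (input_dim + num_units, 4 * num_units); b: (4 * num_units,)
    """
    # One matrix multiply covers all four gates: concat = [inputs, h] W + b
    concat = np.concatenate([inputs, h], axis=1) @ W + b
    # i = input gate, j = new input, f = forget gate, o = output gate
    i, j, f, o = np.split(concat, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)

# Illustrative sizes; input_dim deliberately differs from num_units.
rng = np.random.default_rng(0)
batch_size, input_dim, num_units = 2, 5, 3
W = rng.normal(scale=0.1, size=(input_dim + num_units, 4 * num_units))
b = np.zeros(4 * num_units)
x = rng.normal(size=(batch_size, input_dim))
c0 = np.zeros((batch_size, num_units))
h0 = np.zeros((batch_size, num_units))

out, (c1, h1) = basic_lstm_step(x, c0, h0, W, b)
```

With all-zero weights and biases the forget gate evaluates to sigmoid(forget_bias), so a positive forget_bias makes the cell keep most of its previous state early in training, which is the purpose stated in the class docstring.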

    5. The LSTM layer: computation for each batch

            with tf.variable_scope("RNN"):
                for time_step in range(num_steps):
                    if time_step > 0: tf.get_variable_scope().reuse_variables()
                    (cell_output, state) = cell(inputs[:, time_step, :], state)
                    outputs.append(cell_output)
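The unrolled loop above can be mimicked in NumPy to see how the state is threaded through time. This is a sketch under assumptions: step is a hypothetical stand-in for the cell, the sizes are arbitrary, and the input width is set equal to num_units as in the post.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def step(x_t, state, W, b, forget_bias=0.0):
    # Same arithmetic as BasicLSTMCell.__call__, written in NumPy.
    c, h = state
    concat = np.concatenate([x_t, h], axis=1) @ W + b
    i, j, f, o = np.split(concat, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)

batch_size, num_steps, num_units = 2, 4, 3  # illustrative sizes
rng = np.random.default_rng(0)
inputs = rng.normal(size=(batch_size, num_steps, num_units))

# A single W and b are shared by every time step; in the TensorFlow graph
# this sharing is what reuse_variables() provides after the first step.
W = rng.normal(scale=0.1, size=(2 * num_units, 4 * num_units))
b = np.zeros(4 * num_units)

# Zero initial state, like cell.zero_state(batch_size, tf.float32).
state = (np.zeros((batch_size, num_units)), np.zeros((batch_size, num_units)))
outputs = []
for time_step in range(num_steps):
    cell_output, state = step(inputs[:, time_step, :], state, W, b)
    outputs.append(cell_output)

# Stack the per-step outputs into (batch_size, num_steps, num_units).
stacked = np.stack(outputs, axis=1)
```

The final state returned by the loop holds the last c and h, so the last slice of stacked equals state[1].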

    6. Each epoch

    7. The full computation

  • Original post: https://www.cnblogs.com/yuetz/p/6563377.html