  • TensorFlow source code analysis: BasicLSTMCell

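    BasicLSTMCell is the simplest LSTM cell. Its source is in /tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py.
    BasicLSTMCell inherits from RNNCell, whose source is in /tensorflow/python/ops/rnn_cell_impl.py.
    Notes:
    1. The input_size parameter is deprecated and unused; the cell size is set by num_units.
    2. The official recommendation is to set state_is_tuple to True. The input and output states are then 2-tuples of c (the cell state) and h (the output).
    3. The input, the output, and the cell state all have the same shape here, batch_size * num_units (for the input this holds because its width is chosen equal to num_units):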
    cell = tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=0.0, state_is_tuple=True)  # set num_units
    _initial_state = cell.zero_state(batch_size, tf.float32)                   # set batch_size; c and h are both initialized to zeros with shape batch_size * num_units
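The shape bookkeeping in notes 2 and 3 can be sketched without TensorFlow. The helper names `zero_state_shapes` and `linear_param_shapes` are hypothetical, as is the parameter name `input_dim` (the width of the input); the `LSTMStateTuple` here is a plain namedtuple standing in for the TensorFlow class. It only mirrors the shapes one would expect from `zero_state` and from the cell's fused linear layer, and is not the library's implementation.

```python
from collections import namedtuple

# Stand-in for TensorFlow's LSTMStateTuple; fields follow the (c, h) order.
LSTMStateTuple = namedtuple("LSTMStateTuple", ("c", "h"))


def zero_state_shapes(batch_size, num_units, state_is_tuple=True):
    """Shape(s) of the initial state for one cell (hypothetical helper)."""
    if state_is_tuple:
        return LSTMStateTuple(c=(batch_size, num_units),
                              h=(batch_size, num_units))
    # Concatenated layout: c and h side by side along the column axis.
    return (batch_size, 2 * num_units)


def linear_param_shapes(input_dim, num_units):
    """Shapes of the fused weight and bias that serve all four gates."""
    weight = (input_dim + num_units, 4 * num_units)
    bias = (4 * num_units,)
    return weight, bias


if __name__ == "__main__":
    print(zero_state_shapes(batch_size=20, num_units=200))
    print(zero_state_shapes(batch_size=20, num_units=200, state_is_tuple=False))
    print(linear_param_shapes(input_dim=200, num_units=200))
```

With `input_dim` equal to `num_units`, the weight shape reduces to (2 * num_units, 4 * num_units).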

    4. Source code:
    class BasicLSTMCell(RNNCell):
      """Basic LSTM recurrent network cell.
    
      The implementation is based on: http://arxiv.org/abs/1409.2329.
    
      We add forget_bias (default: 1) to the biases of the forget gate in order to
      reduce the scale of forgetting in the beginning of the training.
    
      It does not allow cell clipping, a projection layer, and does not
      use peep-hole connections: it is the basic baseline.
    
      For advanced models, please use the full LSTMCell that follows.
      """
    
      def __init__(self, num_units, forget_bias=1.0, input_size=None,
                   state_is_tuple=True, activation=tanh):
        """Initialize the basic LSTM cell.
    
        Args:
          num_units: int, The number of units in the LSTM cell.
          forget_bias: float, The bias added to forget gates (see above).
          input_size: Deprecated and unused.
          state_is_tuple: If True, accepted and returned states are 2-tuples of
            the `c_state` and `m_state`.  If False, they are concatenated
            along the column axis.  The latter behavior will soon be deprecated.
          activation: Activation function of the inner states.
        """
        if not state_is_tuple:
          logging.warn("%s: Using a concatenated state is slower and will soon be "
                       "deprecated.  Use state_is_tuple=True.", self)
        if input_size is not None:
          logging.warn("%s: The input_size parameter is deprecated.", self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation
    
      @property
      def state_size(self):
        return (LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)
    
      @property
      def output_size(self):
        return self._num_units
    
      def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
        with vs.variable_scope(scope or "basic_lstm_cell"):
          # Parameters of gates are concatenated into one multiply for efficiency.
          if self._state_is_tuple:
            c, h = state
          else:
            c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

          # Linear step: concat = [inputs, h] W + b.
          # _linear allocates W and b. W has shape (input width + num_units, 4 * num_units),
          # which is (2 * num_units, 4 * num_units) when the input width equals num_units;
          # b has shape (4 * num_units,). Together they hold four sets of gate parameters.
          # concat has shape (batch_size, 4 * num_units).
          # Note: inputs and h are concatenated along axis 1 before the multiply, so one
          # W and one b are enough even when the input width differs from num_units.
          concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)

          # i = input_gate, j = new_input, f = forget_gate, o = output_gate
          i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

          new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                   self._activation(j))
          new_h = self._activation(new_c) * sigmoid(o)

          if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
          else:
            new_state = array_ops.concat([new_c, new_h], 1)
          return new_h, new_state
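The arithmetic in `__call__` can be traced by hand with a small pure-Python sketch. It handles a single example (no batch axis) and uses nested lists instead of tensors; the function name `lstm_step` and its argument layout are assumptions made for illustration, and tanh is assumed as the activation (the default in the constructor above).

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, c, h, W, b, forget_bias=1.0):
    """One BasicLSTMCell-style step for a single example (illustrative only).

    x: list of input floats; c, h: lists of num_units floats.
    W: (len(x) + num_units) rows by 4 * num_units columns.
    b: list of 4 * num_units floats.
    """
    n = len(c)
    xh = list(x) + list(h)  # concatenate [inputs, h]
    # Fused linear layer: concat = [x, h] W + b
    concat = [
        sum(xh[r] * W[r][col] for r in range(len(xh))) + b[col]
        for col in range(4 * n)
    ]
    # Split into the four pre-activations, in the order i, j, f, o.
    i, j, f, o = (concat[k * n:(k + 1) * n] for k in range(4))
    new_c = [
        c[k] * sigmoid(f[k] + forget_bias) + sigmoid(i[k]) * math.tanh(j[k])
        for k in range(n)
    ]
    new_h = [math.tanh(new_c[k]) * sigmoid(o[k]) for k in range(n)]
    return new_h, (new_c, new_h)


if __name__ == "__main__":
    num_units, input_dim = 2, 3
    W = [[0.0] * (4 * num_units) for _ in range(input_dim + num_units)]
    b = [0.0] * (4 * num_units)
    x, c, h = [1.0, 2.0, 3.0], [1.0, -1.0], [0.0, 0.0]
    for fb in (0.0, 1.0):
        _, (new_c, _) = lstm_step(x, c, h, W, b, forget_bias=fb)
        print(fb, new_c)
```

With all-zero parameters the forget gate equals sigmoid(forget_bias), so the fraction of the old cell state that is kept is 0.5 for forget_bias=0.0 and roughly 0.73 for forget_bias=1.0. This is the effect the docstring describes as reducing forgetting at the beginning of training.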

    5. LSTM layer: computation for each batch

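            outputs = []
            state = _initial_state  # start from the zero state created above
            with tf.variable_scope("RNN"):
                for time_step in range(num_steps):
                    # Variables are created on the first step and reused afterwards,
                    # so every time step shares the same W and b.
                    if time_step > 0: tf.get_variable_scope().reuse_variables()
                    (cell_output, state) = cell(inputs[:, time_step, :], state)
                    outputs.append(cell_output)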
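The control flow of this loop can be sketched in plain Python. `toy_cell` is a hypothetical stand-in that simply adds the input to the state; it is not an LSTM and is used only so that the unrolling pattern (slice one time step, call the same cell, carry the state forward, collect the outputs) can be run and checked. The function name `unroll` and the nested-list layout `inputs[batch][time_step][feature]` are also assumptions.

```python
def toy_cell(x_t, state):
    """Hypothetical stand-in for an RNN cell: adds the input to the state.

    x_t and state are lists of per-example feature lists (batch-major).
    Returns (output, new_state), like a cell call does.
    """
    new_state = [
        [s + x for s, x in zip(s_row, x_row)]
        for s_row, x_row in zip(state, x_t)
    ]
    return new_state, new_state


def unroll(cell, inputs, initial_state):
    """Apply the same cell at every time step and collect the outputs."""
    num_steps = len(inputs[0])
    state = initial_state
    outputs = []
    for time_step in range(num_steps):
        # Equivalent of inputs[:, time_step, :] for nested lists.
        x_t = [example[time_step] for example in inputs]
        cell_output, state = cell(x_t, state)
        outputs.append(cell_output)
    return outputs, state


if __name__ == "__main__":
    inputs = [
        [[1.0, 0.0], [2.0, 1.0], [3.0, 2.0]],  # example 0: 3 steps, 2 features
        [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]],  # example 1
    ]
    initial_state = [[0.0, 0.0], [0.0, 0.0]]
    outputs, final_state = unroll(toy_cell, inputs, initial_state)
    print(len(outputs), final_state)
```

In the sketch the parameters are shared across steps automatically, because the same Python function is called each time. In the TensorFlow 1.x graph code, the call to `reuse_variables()` after the first step is what prevents a new set of variables from being created per time step.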

    6. Each epoch

    7. The full computation

  • Original article: https://www.cnblogs.com/yuetz/p/6563377.html