  • TF_RNNCell

    Reference: link

    RNNCell

    • BasicRNNCell
    • GRUCell
    • BasicLSTMCell
    • LSTMCell
    • MultiRNNCell

    The abstract class RNNCell

    Every RNN cell inherits from RNNCell, which mainly declares a few abstract methods:

    def __call__(self, inputs, state, scope=None):
      raise NotImplementedError("Abstract method")

    @property
    def state_size(self):
      raise NotImplementedError("Abstract method")

    @property
    def output_size(self):
      raise NotImplementedError("Abstract method")

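    Of these, __call__ runs whenever the cell object is called like a function. The other two are properties that report the size of the state and the size of the cell's output. Because calling the object invokes __call__, each kind of RNN defines its computation in that method. The sections below go through the individual cells in turn.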
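
    To make the contract concrete, here is a minimal pure-Python sketch of a subclass. RunningSumCell is a hypothetical toy cell invented for illustration (it is not part of TensorFlow); it only shows how state_size, output_size and __call__ fit together.

```python
class RNNCell(object):
    """Abstract interface with the three members quoted above."""

    def __call__(self, inputs, state, scope=None):
        raise NotImplementedError("Abstract method")

    @property
    def state_size(self):
        raise NotImplementedError("Abstract method")

    @property
    def output_size(self):
        raise NotImplementedError("Abstract method")


class RunningSumCell(RNNCell):
    """Hypothetical toy cell: the state is the running sum of the inputs."""

    def __init__(self, num_units):
        self._num_units = num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        # Add the current input to the previous state, element by element.
        new_state = [s + x for s, x in zip(state, inputs)]
        return new_state, new_state


cell = RunningSumCell(num_units=2)
state = [0.0] * cell.state_size           # initial state sized by the property
for step_inputs in ([1.0, 2.0], [3.0, 4.0]):
    # Calling the object is what triggers __call__.
    output, state = cell(step_inputs, state)
```

    The driver loop never needs to know which cell it is running; it only relies on the three members of the interface.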

    BasicRNNCell

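    This is the most basic RNNCell. It can be seen as an extension of an ordinary fully connected layer that adds a temporal connection along the horizontal (time) axis, as in the figure below:

    [Figure: structure of BasicRNNCell]

    BasicRNNCell's initializer looks like this: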

    def __init__(self, num_units, input_size=None, activation=tanh):
      if input_size is not None:
        logging.warn("%s: The input_size parameter is deprecated.", self)
      self._num_units = num_units
      self._activation = activation

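    Initialization only requires num_units, the number of hidden units; activation selects the activation function applied to the output. The RNN computation itself is defined in __call__: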

    def __call__(self, inputs, state, scope=None):
      """Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
      with vs.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
        output = self._activation(_linear([inputs, state], self._num_units, True))
      return output, output

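    The logic is straightforward: inputs is the input to the hidden layer, and state is the hidden state from the previous time step, in other words what the hidden layer fed back to itself last time. For the very first input, state has to be initialized; there are many ways to do so, and any of the initializer functions TensorFlow provides can be used. Inside __call__, inputs and state are combined as activation(W * input + U * state + B), and the result is returned both as the output and as the new state that feeds the next time step.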
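
    The same recurrence can be sketched in NumPy. This is an illustration under assumptions, not TensorFlow's code: basic_rnn_step and the explicit W, U, B arrays are names chosen here, whereas the real cell creates its weights inside _linear. The initial state is simply zeros.

```python
import numpy as np


def basic_rnn_step(inputs, state, W, U, B, activation=np.tanh):
    """One step: output = new_state = activation(W * input + U * state + B)."""
    output = activation(inputs @ W + state @ U + B)
    return output, output


rng = np.random.default_rng(0)
batch_size, input_size, num_units = 2, 3, 4

# Illustrative weights; shapes follow from batch-major row vectors.
W = rng.normal(scale=0.1, size=(input_size, num_units))
U = rng.normal(scale=0.1, size=(num_units, num_units))
B = np.zeros(num_units)

state = np.zeros((batch_size, num_units))            # zero initial state
sequence = rng.normal(size=(5, batch_size, input_size))
for x_t in sequence:
    # The state returned at one step is the state passed in at the next.
    output, state = basic_rnn_step(x_t, state, W, U, B)
```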

    GRUCell

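    GRU is an improvement on the plain RNN and, compared with LSTM, can also be seen as a simplification of it. It was proposed in 2014 by Cho et al., with Bengio among the authors, for machine translation. First, the basic structure of a GRU:

    [Figure: basic structure of a GRU]

    Here we read the principle alongside the code: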

    def __call__(self, inputs, state, scope=None):
      """Gated recurrent unit (GRU) with nunits cells."""
      with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
        with vs.variable_scope("Gates"):  # Reset gate and update gate.
          # We start with bias of 1.0 to not reset and not update.
          r, u = array_ops.split(1, 2, _linear([inputs, state],
                                               2 * self._num_units, True, 1.0))
          r, u = sigmoid(r), sigmoid(u)
        with vs.variable_scope("Candidate"):
          c = self._activation(_linear([inputs, r * state],
                                       self._num_units, True))
        new_h = u * state + (1 - u) * c
      return new_h, new_h

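    GRUCell is initialized the same way as the basic RNN cell, and it is called with the inputs and an initial state. When the object is called, it uses the inputs and the hidden state from the previous time step to compute the two gates, r (reset) and u (update). It then uses r to build the candidate state c, and finally uses u to blend the old state with c into the new hidden state. Among the improvements on RNN, the most powerful is the one covered next, which also has many variants; TensorFlow ships only a few simple, common cells. Let us now look at LSTM.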
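
    The gate arithmetic can be sketched in NumPy as follows. This is a minimal sketch, not TensorFlow's implementation: gru_step, W_gates, b_gates, W_cand and b_cand are names assumed for illustration, and each _linear call is written as a concatenation followed by one matrix multiply. The gate bias starts at 1.0, as in the quoted code.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(inputs, state, W_gates, b_gates, W_cand, b_cand):
    """One GRU step following the quoted __call__."""
    num_units = state.shape[1]
    # Reset gate r and update gate u come from a single matrix multiply.
    gates = sigmoid(np.concatenate([inputs, state], axis=1) @ W_gates + b_gates)
    r, u = gates[:, :num_units], gates[:, num_units:]
    # The candidate state sees the previous state only through the reset gate.
    c = np.tanh(np.concatenate([inputs, r * state], axis=1) @ W_cand + b_cand)
    # u close to 1 keeps the old state; u close to 0 takes the candidate.
    new_h = u * state + (1 - u) * c
    return new_h, new_h


rng = np.random.default_rng(0)
batch_size, input_size, num_units = 2, 3, 4
W_gates = rng.normal(scale=0.1, size=(input_size + num_units, 2 * num_units))
b_gates = np.ones(2 * num_units)          # bias of 1.0: do not reset, do not update
W_cand = rng.normal(scale=0.1, size=(input_size + num_units, num_units))
b_cand = np.zeros(num_units)

state = np.zeros((batch_size, num_units))
for x_t in rng.normal(size=(5, batch_size, input_size)):
    output, state = gru_step(x_t, state, W_gates, b_gates, W_cand, b_cand)
```

    Because new_h is a weighted average of the old state and the tanh-bounded candidate, the state stays inside (-1, 1) when it starts there.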

    BasicLSTMCell

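    This cell can be regarded as the simplest LSTM: it has no extra connections (for example, no peephole connections), which are exactly the refinements other variants add. BasicLSTMCell is shown in the figure below:

    [Figure: structure of BasicLSTMCell]

    As before, we read the principle alongside the code: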

    def __call__(self, inputs, state, scope=None):
      """Long short-term memory cell (LSTM)."""
      with vs.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
          c, h = state
        else:
          c, h = array_ops.split(1, 2, state)
        concat = _linear([inputs, h], 4 * self._num_units, True)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(1, 4, concat)

        new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                 self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
          new_state = LSTMStateTuple(new_c, new_h)
        else:
          new_state = array_ops.concat(1, [new_c, new_h])
        return new_h, new_state

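    An LSTM has three gates (input, forget and output), and the cell state in the middle combines them to produce the required output. After initialization, the inputs and the previous hidden state h are used to compute each gate, and the three gates are then used to update the cell state c and the current hidden state h. The forget gate f controls how much of the previous cell state is kept, while the input gate i controls how much of the new candidate j is written; together they update the cell. An LSTM can be viewed as a complicated activation function whose existence depends on the recurrence of the RNN. BasicLSTMCell is only the most basic LSTM; a full LSTM can be more complex than this, see the blog.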
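
    The update can be sketched in NumPy as follows, using the tuple form of the state (c, h). This is an illustration under assumptions rather than TensorFlow's code: basic_lstm_step, W and b are names chosen here, tanh stands in for self._activation, and forget_bias=1.0 is an illustrative value.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def basic_lstm_step(inputs, state, W, b, forget_bias=1.0):
    """One LSTM step following the quoted __call__, with state = (c, h)."""
    c, h = state
    # All four blocks are computed with a single matrix multiply.
    concat = np.concatenate([inputs, h], axis=1) @ W + b
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = np.split(concat, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)


rng = np.random.default_rng(0)
batch_size, input_size, num_units = 2, 3, 4
W = rng.normal(scale=0.1, size=(input_size + num_units, 4 * num_units))
b = np.zeros(4 * num_units)

state = (np.zeros((batch_size, num_units)), np.zeros((batch_size, num_units)))
for x_t in rng.normal(size=(5, batch_size, input_size)):
    output, state = basic_lstm_step(x_t, state, W, b)
```

    Note that the output equals the new h, while the cell state c is carried along only inside the state.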

    MultiRNNCell

    For MultiRNNCell, the only option is to paste the complete code and analyze it:

    class MultiRNNCell(RNNCell):
      """RNN cell composed sequentially of multiple simple cells."""

      def __init__(self, cells, state_is_tuple=False):
        """Create a RNN cell composed sequentially of a number of RNNCells.

        Args:
          cells: list of RNNCells that will be composed in this order.
          state_is_tuple: If True, accepted and returned states are n-tuples, where
            `n = len(cells)`.  By default (False), the states are all
            concatenated along the column axis.

        Raises:
          ValueError: if cells is empty (not allowed), or at least one of the cells
            returns a state tuple but the flag `state_is_tuple` is `False`.
        """
        if not cells:
          raise ValueError("Must specify at least one cell for MultiRNNCell.")
        self._cells = cells
        self._state_is_tuple = state_is_tuple
        if not state_is_tuple:
          if any(nest.is_sequence(c.state_size) for c in self._cells):
            raise ValueError("Some cells return tuples of states, but the flag "
                             "state_is_tuple is not set.  State sizes are: %s"
                             % str([c.state_size for c in self._cells]))

      @property
      def state_size(self):
        if self._state_is_tuple:
          return tuple(cell.state_size for cell in self._cells)
        else:
          return sum([cell.state_size for cell in self._cells])

      @property
      def output_size(self):
        return self._cells[-1].output_size

      def __call__(self, inputs, state, scope=None):
        """Run this multi-layer cell on inputs, starting from state."""
        with vs.variable_scope(scope or type(self).__name__):  # "MultiRNNCell"
          cur_state_pos = 0
          cur_inp = inputs
          new_states = []
          for i, cell in enumerate(self._cells):
            with vs.variable_scope("Cell%d" % i):
              if self._state_is_tuple:
                if not nest.is_sequence(state):
                  raise ValueError(
                      "Expected state to be a tuple of length %d, but received: %s"
                      % (len(self.state_size), state))
                cur_state = state[i]
              else:
                cur_state = array_ops.slice(
                    state, [0, cur_state_pos], [-1, cell.state_size])
                cur_state_pos += cell.state_size
              cur_inp, new_state = cell(cur_inp, cur_state)
              new_states.append(new_state)
        new_states = (tuple(new_states) if self._state_is_tuple
                      else array_ops.concat(1, new_states))
        return cur_inp, new_states

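    When the object is created, the initializer no longer takes num_units but a list of cells: each cell is one layer, so several cells give a multi-layer RNN cell, and a single layer is just a special case. When the object is called with inputs and state, each cell receives its own slice of the state and the output of the cell below it as input, so one call produces the current hidden state of every cell, ready for the next time step. It looks complicated but is actually quite simple: it just loops the computation over several cells. The returned output is the output of the last cell, and the returned state collects the new state of every cell.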
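
    The stacking loop can be sketched in NumPy as follows, keeping only the tuple form of the state. TanhCell and StackedCell are hypothetical names used for illustration; they are not TensorFlow classes, and variable scopes are left out.

```python
import numpy as np


class TanhCell(object):
    """Hypothetical minimal cell: output = new_state = tanh(x W + h U)."""

    def __init__(self, input_size, num_units, rng):
        self.W = rng.normal(scale=0.1, size=(input_size, num_units))
        self.U = rng.normal(scale=0.1, size=(num_units, num_units))
        self.state_size = num_units
        self.output_size = num_units

    def __call__(self, inputs, state):
        output = np.tanh(inputs @ self.W + state @ self.U)
        return output, output


class StackedCell(object):
    """Sketch of the MultiRNNCell loop with tuple states only."""

    def __init__(self, cells):
        if not cells:
            raise ValueError("Must specify at least one cell.")
        self._cells = cells

    @property
    def state_size(self):
        return tuple(cell.state_size for cell in self._cells)

    @property
    def output_size(self):
        return self._cells[-1].output_size

    def __call__(self, inputs, state):
        cur_inp = inputs
        new_states = []
        for cell, cur_state in zip(self._cells, state):
            # Each layer's output becomes the next layer's input.
            cur_inp, new_state = cell(cur_inp, cur_state)
            new_states.append(new_state)
        return cur_inp, tuple(new_states)


rng = np.random.default_rng(0)
batch_size = 2
stack = StackedCell([TanhCell(3, 4, rng), TanhCell(4, 5, rng)])
state = tuple(np.zeros((batch_size, n)) for n in stack.state_size)
for x_t in rng.normal(size=(6, batch_size, 3)):
    output, state = stack(x_t, state)
```

    The input size of each layer has to match the output size of the layer below it, which is why the second TanhCell takes 4 inputs.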

  • Original article: https://www.cnblogs.com/niuxichuan/p/9152857.html