一、概述
理解了LSTM之后,GRU就很好理解。
首先GRU有两个门:
reset gate 重置门 (r_t):用于控制前一时刻的隐含层状态有多大程度更新到当前候选隐含层状态:
update gate 更新门(z_t):用于控制前一时刻的隐含层状态有多大程度更新到当前隐含层状态:
两个隐藏层:
候选隐藏层( ilde{h}_t),这个候选隐藏层 和LSTM中的 (c_t)是类似,可以看成是当前时刻的新信息,其中(r_t)用来控制需要保留多少之前的记忆:
( ilde{h}_t) 就是GRU记录到的所有重要信息,表示当前记忆内容比如在语言模型中,可能保存了主语单复数,主语的性别,当前时态等所有记录的重要信息。
隐藏层(h_t)。最后(z_t)控制需要从前一时刻的隐藏层(h_{t-1})中遗忘多少信息,需要加入多少当前时刻的隐藏层信息( ilde{h}_t)。(h_t)即为最后输出的隐藏层信息:
需要注意的是,虽然隐藏层信息的符号和当前记忆内容的符号相似,但是这两者是有一定的区别的。当前记忆内容在上文中已经说明了是当前时刻保存的所有信息,而隐藏层信息则是当前时刻所需要的信息。
比如在语言模型中,在当前时刻可能我们只需要知道当前时态和主语单复数就可以确定当前动词使用什么时态,而不需要其他更多的信息。
一般来说那些具有短距离依赖的单元reset gate比较活跃(如果(r_t)为1,而(z_T)为0 那么相当于变成了一个标准的RNN,能处理短距离依赖),具有长距离依赖的单元update gate比较活跃。
二、一个示例
tensorflow有两个类实现了GRU
- tf.contrib.rnn.GRUCell
- tf.nn.rnn_cell.GRUCell
import tensorflow as tf
batch_size=10
depth=128
output_dim=100
inputs=tf.Variable(tf.random_normal([batch_size,depth]))
previous_state=tf.Variable(tf.random_normal([batch_size,output_dim])) #前一个状态的输出
gruCell=tf.nn.rnn_cell.GRUCell(output_dim) # 隐层神经元个数为output_dim
output,state=gruCell(inputs,previous_state)
print(output) # shape=(10, 100),
print(state) # shape=(10, 100) 返回相同的值
GRU的输出和LSTM的区别:
GRU返回值output 和 state具有相同的值。
参考资料
1、深入理解lstm及其变种gru
https://zhuanlan.zhihu.com/p/34203833