运行以下代码,进入~/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py和~/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell.py单步追踪调试
调试过程中,可在调试器提示符下先执行import tensorflow as tf,再用tf.Session().run(variable)打印变量的值
查看BasicRNNCell和dynamic_rnn的实现方式
# -*- coding: utf-8 -*-

__author = "buyizhiyou"
__date = "2017-11-20"

'''
Single-step debugging script for studying TensorFlow's RNN implementation.
'''
import tensorflow as tf
import numpy as np
import pdb

# Random input tensor of shape (2, 3, 4):
# (batch_size, time_steps (sequence length), data_vector).
X = tf.random_normal(shape=[2, 3, 4], dtype=tf.float32)

# Break here, before the cell and the RNN graph are built, so the calls
# below can be stepped into from the debugger.
pdb.set_trace()

# Output size is 10; the cell can also be swapped for GRUCell or LSTMCell.
cell = tf.nn.rnn_cell.BasicRNNCell(10)
# The initial state is built for a batch size of 2, matching X's first dimension.
state = cell.zero_state(2, tf.float32)
# time_major=False: X is laid out batch-first, as noted above.
output, state = tf.nn.dynamic_rnn(cell, X, initial_state=state, time_major=False)

with tf.Session() as sess:
    # The cell's variables must be initialized before `state` is evaluated.
    sess.run(tf.global_variables_initializer())
    print(output.get_shape())
    print(sess.run(state))