# Minimal working example of tf.contrib.legacy_seq2seq.basic_rnn_seq2seq
# API docs: https://www.tensorflow.org/api_docs/python/tf/contrib/legacy_seq2seq/basic_rnn_seq2seq
import tensorflow as tf
import numpy as np
# Sizes for this toy example.
steps = 10
batch_size = 10
input_size = 10
num_units = 10  # hidden size of the LSTM cell (was a bare literal)

# Batch-major placeholders: (batch, steps, input_size), with a dynamic batch
# dimension. Axis 1 is the time axis that get_result unstacks.
encoder_inputs = tf.placeholder("float", [None, steps, input_size])
decoder_inputs = tf.placeholder("float", [None, steps, input_size])

# Dummy all-zero feed arrays, laid out batch-first to match the placeholders.
# Because steps == batch_size here, a time-major [steps, batch_size, ...]
# array would also feed without error, but its axes would be swapped
# relative to what the graph expects.
en_input = np.zeros(shape=[batch_size, steps, input_size])
de_input = np.zeros(shape=[batch_size, steps, input_size])

# One cell object, passed to basic_rnn_seq2seq for both encoder and decoder.
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
def get_result(encoder_inputs, decoder_inputs, cell):
    """Build a basic RNN seq2seq graph from two stacked input tensors.

    ``tf.contrib.legacy_seq2seq.basic_rnn_seq2seq`` takes Python lists of
    per-step tensors, so each input is split along axis 1 (treated as the
    time axis) before the call.

    Args:
        encoder_inputs: tensor for the encoder; unstacked along axis 1.
        decoder_inputs: tensor for the decoder; unstacked along axis 1.
        cell: RNN cell handed to ``basic_rnn_seq2seq``.

    Returns:
        The value returned by ``basic_rnn_seq2seq`` (per the TF docs, an
        ``(outputs, state)`` tuple), built with float32 and the default scope.
    """
    enc_steps = tf.unstack(encoder_inputs, axis=1)
    dec_steps = tf.unstack(decoder_inputs, axis=1)
    return tf.contrib.legacy_seq2seq.basic_rnn_seq2seq(
        enc_steps,
        dec_steps,
        cell,
        dtype=tf.float32,
        scope=None,
    )
# Build the seq2seq graph, then the op that initializes its variables.
result = get_result(encoder_inputs, decoder_inputs, cell)
init = tf.global_variables_initializer()

# Map each placeholder to its dummy numpy input.
feeds = {encoder_inputs: en_input, decoder_inputs: de_input}

with tf.Session() as sess:
    # Variables must be initialized before the graph can be evaluated.
    sess.run(init)
    result_value = sess.run(result, feed_dict=feeds)
    print(result_value)