  • TensorFlow -- BeamSearch

     GitHub: https://github.com/zle1992/Seq2Seq-Chatbot

    1. Note that in the infer (prediction) stage, the variable scopes must be reused (reuse=True).

     2. If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure the following (a minimal sketch follows this list):

    • The encoder output has been tiled to beam_width via tf.contrib.seq2seq.tile_batch (NOT tf.tile).
    • The batch_size argument passed to the zero_state method of this wrapper is equal to true_batch_size * beam_width.
    • The initial state created with zero_state above contains a cell_state value containing properly tiled final state from the encoder.
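
A minimal standalone sketch of just these three points (TF 1.x contrib API). The batch size, dimensions, and the zero-filled encoder tensors are illustrative stand-ins for real encoder outputs; the full graph follows below.

import tensorflow as tf

true_batch_size = 32   # illustrative
beam_width = 5

# Zero-filled stand-ins for what the encoder's dynamic_rnn would return.
encoder_out = tf.zeros([true_batch_size, 10, 128])      # [batch, time, units]
X_seq_len = tf.fill([true_batch_size], 10)              # [batch]
encoder_state = tf.nn.rnn_cell.LSTMStateTuple(
    c=tf.zeros([true_batch_size, 128]), h=tf.zeros([true_batch_size, 128]))

# 1. Tile memory, lengths and final state to beam_width with tile_batch (NOT tf.tile).
tiled_out = tf.contrib.seq2seq.tile_batch(encoder_out, multiplier=beam_width)
tiled_len = tf.contrib.seq2seq.tile_batch(X_seq_len, multiplier=beam_width)
tiled_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)

attention = tf.contrib.seq2seq.LuongAttention(
    num_units=128, memory=tiled_out, memory_sequence_length=tiled_len)
cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.BasicLSTMCell(128), attention, attention_layer_size=128)

# 2. zero_state is built with true_batch_size * beam_width ...
# 3. ... and its cell_state is replaced by the tiled encoder final state.
initial_state = cell.zero_state(
    true_batch_size * beam_width, tf.float32).clone(cell_state=tiled_state)
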
import tensorflow as tf
from tensorflow.python.layers.core import Dense


BEAM_WIDTH = 5
BATCH_SIZE = 128


# INPUTS
X = tf.placeholder(tf.int32, [BATCH_SIZE, None])
Y = tf.placeholder(tf.int32, [BATCH_SIZE, None])
X_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])
Y_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])


# ENCODER
encoder_out, encoder_state = tf.nn.dynamic_rnn(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    inputs = tf.contrib.layers.embed_sequence(X, 10000, 128),
    sequence_length = X_seq_len,
    dtype = tf.float32)


# DECODER COMPONENTS
Y_vocab_size = 10000
decoder_embedding = tf.Variable(tf.random_uniform([Y_vocab_size, 128], -1.0, 1.0))
projection_layer = Dense(Y_vocab_size)


# ATTENTION (TRAINING)
with tf.variable_scope('shared_attention_mechanism'):
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units = 128,
        memory = encoder_out,
        memory_sequence_length = X_seq_len)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    attention_mechanism = attention_mechanism,
    attention_layer_size = 128)


# DECODER (TRAINING)
training_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs = tf.nn.embedding_lookup(decoder_embedding, Y),
    sequence_length = Y_seq_len,
    time_major = False)
training_decoder = tf.contrib.seq2seq.BasicDecoder(
    cell = decoder_cell,
    helper = training_helper,
    # zero_state for the true batch size, with cell_state set to the encoder's final state
    initial_state = decoder_cell.zero_state(BATCH_SIZE, tf.float32).clone(cell_state=encoder_state),
    output_layer = projection_layer)
with tf.variable_scope('decode_with_shared_attention'):
    training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder = training_decoder,
        impute_finished = True,
        maximum_iterations = tf.reduce_max(Y_seq_len))
training_logits = training_decoder_output.rnn_output


# BEAM SEARCH TILE
# Tile with tile_batch (NOT tf.tile); new names are used so the placeholders
# above stay directly feedable.
tiled_encoder_out = tf.contrib.seq2seq.tile_batch(encoder_out, multiplier=BEAM_WIDTH)
tiled_X_seq_len = tf.contrib.seq2seq.tile_batch(X_seq_len, multiplier=BEAM_WIDTH)
tiled_encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=BEAM_WIDTH)


# ATTENTION (PREDICTING)
# reuse=True so the attention parameters are shared with the training graph.
with tf.variable_scope('shared_attention_mechanism', reuse=True):
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units = 128,
        memory = tiled_encoder_out,
        memory_sequence_length = tiled_X_seq_len)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    attention_mechanism = attention_mechanism,
    attention_layer_size = 128)


# DECODER (PREDICTING)
predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell = decoder_cell,
    embedding = decoder_embedding,
    start_tokens = tf.tile(tf.constant([1], dtype=tf.int32), [BATCH_SIZE]),
    end_token = 2,
    # zero_state batch size must be BATCH_SIZE * BEAM_WIDTH, with the tiled encoder state
    initial_state = decoder_cell.zero_state(BATCH_SIZE * BEAM_WIDTH, tf.float32).clone(cell_state=tiled_encoder_state),
    beam_width = BEAM_WIDTH,
    output_layer = projection_layer,
    length_penalty_weight = 0.0)
with tf.variable_scope('decode_with_shared_attention', reuse=True):
    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder = predicting_decoder,
        impute_finished = False,
        maximum_iterations = 2 * tf.reduce_max(Y_seq_len))
# predicted_ids has shape [batch, time, beam_width]; keep the best beam.
# (Despite the name, these are token ids, not logits.)
predicting_logits = predicting_decoder_output.predicted_ids[:, :, 0]

print('successful')
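
For completeness, a rough sketch of how the graph above could be exercised in a session. The batches are random stand-in ids, and since no loss or optimizer is defined above, this only runs the forward pass of both decoders to check the output shapes.

import numpy as np

x_batch = np.random.randint(3, 10000, size=(BATCH_SIZE, 20))   # dummy source ids
y_batch = np.random.randint(3, 10000, size=(BATCH_SIZE, 22))   # dummy target ids
x_len = np.full(BATCH_SIZE, 20, dtype=np.int32)
y_len = np.full(BATCH_SIZE, 22, dtype=np.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Training branch: TrainingHelper + BasicDecoder -> per-step logits.
    logits = sess.run(training_logits, feed_dict={
        X: x_batch, Y: y_batch, X_seq_len: x_len, Y_seq_len: y_len})
    # Prediction branch: BeamSearchDecoder; Y itself is not needed, only Y_seq_len
    # (it bounds maximum_iterations above).
    best_beam = sess.run(predicting_logits, feed_dict={
        X: x_batch, X_seq_len: x_len, Y_seq_len: y_len})
    print(logits.shape)     # (BATCH_SIZE, max target length, Y_vocab_size)
    print(best_beam.shape)  # (BATCH_SIZE, number of decode steps)
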

     References:

    https://gist.github.com/higepon/eb81ba0f6663a57ff1908442ce753084

    https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/BeamSearchDecoder

    https://github.com/tensorflow/nmt#beam-search

  • Original post: https://www.cnblogs.com/zle1992/p/10608376.html