Tensorflow -- BeamSearch

GitHub: https://github.com/zle1992/Seq2Seq-Chatbot

1. Note that in the infer (prediction) stage, variable reuse is required.
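Not part of the original snippet, but here is a minimal toy sketch of what reuse means (the scope and variable names are made up for illustration): with reuse=True, the second tf.get_variable call returns the variable already created for training, so the inference graph shares the trained weights instead of creating new ones.

import tensorflow as tf

# Toy illustration (hypothetical names, not from the chatbot code):
# with reuse=True the scope hands back the variable created for training,
# so the inference path shares the trained weights.
with tf.variable_scope('decoder_demo'):
    w_train = tf.get_variable('w', shape=[4, 4])
with tf.variable_scope('decoder_demo', reuse=True):
    w_infer = tf.get_variable('w', shape=[4, 4])
print(w_train is w_infer)  # True: same underlying variable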

2. If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:

    • The encoder output has been tiled to beam_width via tf.contrib.seq2seq.tile_batch (NOT tf.tile); see the sketch after this list.
    • The batch_size argument passed to the zero_state method of this wrapper is equal to true_batch_size * beam_width.
    • The initial state created with zero_state above contains a cell_state value containing properly tiled final state from the encoder.
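As a minimal sketch of the first requirement (toy values, not from the chatbot code): tile_batch repeats each minibatch entry beam_width times in place, whereas tf.tile would repeat the whole batch, so the beams of example i would no longer sit next to example i's encoder memory.

import tensorflow as tf

# Toy comparison (hypothetical 2-example batch) of tile_batch vs tf.tile:
memory = tf.constant([[1, 1, 1],
                      [2, 2, 2]])                      # batch_size = 2
tiled = tf.contrib.seq2seq.tile_batch(memory, multiplier=3)
naive = tf.tile(memory, [3, 1])

with tf.Session() as sess:
    # tile_batch keeps each example's copies adjacent:
    # [[1 1 1] [1 1 1] [1 1 1] [2 2 2] [2 2 2] [2 2 2]]
    print(sess.run(tiled))
    # tf.tile repeats the whole batch, misaligning beams and memory:
    # [[1 1 1] [2 2 2] [1 1 1] [2 2 2] [1 1 1] [2 2 2]]
    print(sess.run(naive))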
import tensorflow as tf
from tensorflow.python.layers.core import Dense


BEAM_WIDTH = 5
BATCH_SIZE = 128


# INPUTS
X = tf.placeholder(tf.int32, [BATCH_SIZE, None])
Y = tf.placeholder(tf.int32, [BATCH_SIZE, None])
X_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])
Y_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])


# ENCODER
encoder_out, encoder_state = tf.nn.dynamic_rnn(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    inputs = tf.contrib.layers.embed_sequence(X, 10000, 128),
    sequence_length = X_seq_len,
    dtype = tf.float32)


# DECODER COMPONENTS
Y_vocab_size = 10000
decoder_embedding = tf.Variable(tf.random_uniform([Y_vocab_size, 128], -1.0, 1.0))
projection_layer = Dense(Y_vocab_size)


# ATTENTION (TRAINING)
with tf.variable_scope('shared_attention_mechanism'):
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units = 128,
        memory = encoder_out,
        memory_sequence_length = X_seq_len)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    attention_mechanism = attention_mechanism,
    attention_layer_size = 128)


# DECODER (TRAINING)
training_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs = tf.nn.embedding_lookup(decoder_embedding, Y),
    sequence_length = Y_seq_len,
    time_major = False)
training_decoder = tf.contrib.seq2seq.BasicDecoder(
    cell = decoder_cell,
    helper = training_helper,
    initial_state = decoder_cell.zero_state(BATCH_SIZE, tf.float32).clone(cell_state=encoder_state),
    output_layer = projection_layer)
with tf.variable_scope('decode_with_shared_attention'):
    training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder = training_decoder,
        impute_finished = True,
        maximum_iterations = tf.reduce_max(Y_seq_len))
training_logits = training_decoder_output.rnn_output


# BEAM SEARCH TILE
# Tile encoder outputs, sequence lengths and final state to beam_width
# with tile_batch (NOT tf.tile), as required by BeamSearchDecoder.
encoder_out = tf.contrib.seq2seq.tile_batch(encoder_out, multiplier=BEAM_WIDTH)
X_seq_len = tf.contrib.seq2seq.tile_batch(X_seq_len, multiplier=BEAM_WIDTH)
encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=BEAM_WIDTH)


# ATTENTION (PREDICTING)
# reuse=True so the inference attention shares the variables created for training.
with tf.variable_scope('shared_attention_mechanism', reuse=True):
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units = 128,
        memory = encoder_out,
        memory_sequence_length = X_seq_len)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    attention_mechanism = attention_mechanism,
    attention_layer_size = 128)


# DECODER (PREDICTING)
# zero_state is built with batch_size = BATCH_SIZE * BEAM_WIDTH and cloned
# with the tiled encoder state, as the BeamSearchDecoder requires.
predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell = decoder_cell,
    embedding = decoder_embedding,
    start_tokens = tf.tile(tf.constant([1], dtype=tf.int32), [BATCH_SIZE]),
    end_token = 2,
    initial_state = decoder_cell.zero_state(BATCH_SIZE * BEAM_WIDTH, tf.float32).clone(cell_state=encoder_state),
    beam_width = BEAM_WIDTH,
    output_layer = projection_layer,
    length_penalty_weight = 0.0)
with tf.variable_scope('decode_with_shared_attention', reuse=True):
    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder = predicting_decoder,
        impute_finished = False,
        maximum_iterations = 2 * tf.reduce_max(Y_seq_len))
predicting_logits = predicting_decoder_output.predicted_ids[:, :, 0]  # token ids of the best beam

print('successful')
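The original post stops at building the graph. As a hedged usage sketch (random token ids only, assuming id 1 is the start token and id 2 the end token as hard-coded above), the two decoding paths could be smoke-tested like this:

import numpy as np

# Hypothetical smoke test: random ids, just to check that both the
# training path and the beam-search path run on the same graph.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {
        X: np.random.randint(3, 10000, size=(BATCH_SIZE, 20)),
        Y: np.random.randint(3, 10000, size=(BATCH_SIZE, 20)),
        X_seq_len: np.full(BATCH_SIZE, 20, dtype=np.int32),
        Y_seq_len: np.full(BATCH_SIZE, 20, dtype=np.int32),
    }
    logits, beam_ids = sess.run([training_logits, predicting_logits], feed)
    print(logits.shape)    # (128, 20, 10000)
    print(beam_ids.shape)  # (128, T) with T <= 40: best-beam ids per example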

References:

    https://gist.github.com/higepon/eb81ba0f6663a57ff1908442ce753084

    https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/BeamSearchDecoder

    https://github.com/tensorflow/nmt#beam-search

Original post: https://www.cnblogs.com/zle1992/p/10608376.html