  • Debugging and analyzing the attention-based encoder-decoder model in TensorFlow

    #-*-coding:utf8-*-

    __author__ = "buyizhiyou"
    __date__ = "2017-11-21"


    import os
    import random
    import time
    import pdb

    from PIL import Image
    import numpy as np
    import tensorflow as tf

    import decoder
    '''
    In a Chinese-character OCR project, an attention-based encoder-decoder
    (seq2seq) model is trained end to end. Here we single-step through the code
    to trace how TensorFlow implements the attention seq2seq model.
    The Python entry point in seq2seq.py is tf.nn.seq2seq.embedding_attention_seq2seq();
    the parts it uses are pulled out below so they can be debugged in isolation.
    '''

    batch_size = 16
    dec_seq_len = 8  # number of Chinese characters per image
    enc_lstm_dim = 256
    dec_lstm_dim = 512
    vocab_size = 1002
    embedding_size = 100
    lr = 0.01
    global_step = tf.Variable(0)

    # Simulate a CNN-extracted image feature map:
    # (batch_size, height, width, channels) = (16, 10, 35, 64)
    cnn = tf.truncated_normal([16, 10, 35, 64], mean=0, stddev=1.0, dtype=tf.float32)

    # Randomly generate the label sequence for each image in the batch (no embedding needed)
    true_labels = []
    for i in range(batch_size):
        seq_label = []
        for j in range(dec_seq_len):
            seq_label.append(random.randint(0, 1000))
        true_labels.append(seq_label)


    # Encoder
    def encoder(inp):  # inp: shape=(16, 35, 64)
        #pdb.set_trace()
        enc_init_shape = [batch_size, enc_lstm_dim]  # [16, 256]
        with tf.variable_scope('encoder_rnn'):
            with tf.variable_scope('forward'):
                lstm_cell_fw = tf.nn.rnn_cell.LSTMCell(enc_lstm_dim)
                init_fw = tf.nn.rnn_cell.LSTMStateTuple(
                                    tf.get_variable("enc_fw_c", enc_init_shape),
                                    tf.get_variable("enc_fw_h", enc_init_shape)
                                    )
            with tf.variable_scope('backward'):
                lstm_cell_bw = tf.nn.rnn_cell.LSTMCell(enc_lstm_dim)
                init_bw = tf.nn.rnn_cell.LSTMStateTuple(
                                    tf.get_variable("enc_bw_c", enc_init_shape),
                                    tf.get_variable("enc_bw_h", enc_init_shape)
                                    )
            output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_cell_fw,
                                                        lstm_cell_bw,
                                                        inp,
                                                        sequence_length=tf.fill([batch_size],
                                                        tf.shape(inp)[1]),  # (35, 35, ..., 35)
                                                        initial_state_fw=init_fw,
                                                        initial_state_bw=init_bw
                                                        )  # each direction: shape=(16, 35, 256)
        return tf.concat(2, output)  # shape=(16, 35, 512)

    encoder = tf.make_template('fun', encoder)
    # cnn shape is (batch_size, rows, columns, features).
    # Swap axes so rows come first: tf.map_fn splits a tensor on its first axis,
    # so the encoder is applied row by row to tensors of shape (batch_size, time_steps, feat_size).
    rows_first = tf.transpose(cnn, [1, 0, 2, 3])  # shape=(10, 16, 35, 64)
    res = tf.map_fn(encoder, rows_first, dtype=tf.float32)  # shape=(10, 16, 35, 512)
    encoder_output = tf.transpose(res, [1, 0, 2, 3])  # shape=(16, 10, 35, 512)
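
To see tf.map_fn's splitting behavior in isolation, here is a minimal sketch with toy shapes (not the ones above): map_fn splits its input along the first axis and applies the function once per slice, which is why the rows must be transposed to the front first.

    import numpy as np
    import tensorflow as tf

    # Toy "feature map": (batch=2, rows=3, cols=4, feats=5)
    feat = tf.constant(np.arange(120, dtype=np.float32).reshape(2, 3, 4, 5))

    # map_fn splits along axis 0, so move rows to the front; the mapped
    # function then sees one (batch, cols, feats) tensor per row.
    per_row = tf.transpose(feat, [1, 0, 2, 3])                      # (3, 2, 4, 5)
    summed = tf.map_fn(lambda row: tf.reduce_sum(row, 2), per_row)  # (3, 2, 4)

    with tf.Session() as sess:
        print(sess.run(tf.shape(summed)))  # [3 2 4]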

    dec_lstm_cell = tf.nn.rnn_cell.LSTMCell(dec_lstm_dim)
    dec_init_shape = [batch_size, dec_lstm_dim]
    dec_init_state = tf.nn.rnn_cell.LSTMStateTuple(tf.truncated_normal(dec_init_shape),
                                                   tf.truncated_normal(dec_init_shape))

    init_words = np.zeros([batch_size, 1, vocab_size])  # (16, 1, 1002)


    #pdb.set_trace()
    (output, state) = decoder.embedding_attention_decoder(
                        dec_init_state,  # state=[c, h] of the first decoder cell, each (16, 512)
                        tf.reshape(encoder_output, [batch_size, -1, 2 * enc_lstm_dim]),
                        # the encoder output reshaped into the attention states fed to the attention
                        # module: the 10x35 spatial grid is flattened into shape=(16, 350, 512)
                        dec_lstm_cell,   # the LSTM cell used as the decoding layer
                        vocab_size,      # 1002
                        dec_seq_len,     # 8
                        batch_size,      # 16
                        embedding_size,  # 100
                        feed_previous=True)  # dec_seq_len = num_words = time_steps
    pdb.set_trace()
    cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(output, true_labels))
    learning_rate = tf.train.exponential_decay(lr, global_step, 50, 0.9)
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
    correct_prediction = tf.equal(tf.to_int32(tf.argmax(output, 2)), true_labels)
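
For reference, tf.train.exponential_decay with these arguments computes learning_rate = lr * 0.9**(global_step / 50); staircase defaults to False, so the exponent is fractional and the rate decays smoothly as global_step grows.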

    decoder.py

    #-*-coding:utf8-*-


    """
    Excerpted from TensorFlow's seq2seq.py
    """
    import numpy as np
    import tensorflow as tf
    import pdb
    from tensorflow.python import shape
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import embedding_ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops import nn_ops
    from tensorflow.python.ops import rnn
    from tensorflow.python.ops import rnn_cell
    from tensorflow.python.ops import variable_scope
    from tensorflow.python.util import nest

    linear = rnn_cell._linear    # pylint: disable=protected-access

    def attention_decoder(initial_state,      # (16, 512)
                          attention_states,   # shape=(16, 350, 512)
                          cell,
                          vocab_size,         # 1002
                          time_steps,         # num_words, 8
                          batch_size,         # 16
                          output_size=None,   # 512
                          loop_function=None,
                          dtype=None,
                          scope=None):
        pdb.set_trace()
        if attention_states.get_shape()[2].value is None:  # get_shape() returns a tensor's static shape
            raise ValueError("Shape[2] of attention_states must be known: %s"
                             % attention_states.get_shape())
        if output_size is None:
            output_size = cell.output_size  # 512

        with variable_scope.variable_scope(scope or "attention_decoder", dtype=dtype) as scope:
            dtype = scope.dtype

            attn_length = attention_states.get_shape()[1].value  # 350
            if attn_length is None:
                attn_length = shape(attention_states)[1]
            attn_size = attention_states.get_shape()[2].value  # 512

            # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
            hidden = array_ops.reshape(attention_states, [-1, attn_length, 1, attn_size])  # shape=(16, 350, 1, 512)
            attention_vec_size = attn_size    # Size of query vectors for attention: 512
            k = variable_scope.get_variable("AttnW", [1, 1, attn_size, attention_vec_size])  # shape=(1, 1, 512, 512)
            hidden_features = nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")  # (16, 350, 1, 512)  W1*h_j
            v = variable_scope.get_variable("AttnV", [attention_vec_size])

            def attention(query):
                # query: LSTMStateTuple(c=<shape=(16, 512)>, h=<shape=(16, 512)>)
                """Put attention masks on hidden using hidden_features and query."""
                if nest.is_sequence(query):    # If the query is a tuple, flatten it.
                    query_list = nest.flatten(query)  # [c, h]; random on the first step, then taken from the previous computation
                    for q in query_list:    # Check that ndims == 2 if specified.
                        ndims = q.get_shape().ndims
                        if ndims:
                            assert ndims == 2
                    query = array_ops.concat(1, query_list)  # shape=(16, 1024)
                with variable_scope.variable_scope("Attention_0"):
                    y = linear(query, attention_vec_size, True)  # shape=(16, 512)  W2*s_t
                    y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])  # shape=(16, 1, 1, 512)
                    s = math_ops.reduce_sum(
                            v * math_ops.tanh(hidden_features + y), [2, 3])  # Eq. (3), shape=(16, 350)
                    a = nn_ops.softmax(s)  # Eq. (2), shape=(16, 350)
                    # Now calculate the attention-weighted vector d.
                    d = math_ops.reduce_sum(
                            array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden,  # Eq. (1)
                            [1, 2])  # shape=(16, 512)
                    ds = array_ops.reshape(d, [-1, attn_size])  # shape=(16, 512); the three key attention equations end here
                return ds
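
            # For reference, the three equations implemented above are,
            # presumably, the standard (Bahdanau-style) attention formulas,
            # with h_j the attention states and s_t the decoder state:
            #   Eq. (3):  e_{t,j} = v^T tanh(W1*h_j + W2*s_t)   -> s above
            #   Eq. (2):  a_{t,j} = softmax(e_t)_j              -> a above
            #   Eq. (1):  c_t    = sum_j a_{t,j} * h_j          -> d (ds) above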
            #pdb.set_trace()
            prev = array_ops.zeros([batch_size, output_size])  # shape=(16, 512); the input the first decoder cell
                                                               # needs to start; zero-initialized here, later cells
                                                               # use the previous step's output
            batch_attn_size = array_ops.pack([batch_size, attn_size])  # 1-D tensor [16, 512]
            attn = array_ops.zeros(batch_attn_size, dtype=dtype)  # shape=(16, 512)
            attn.set_shape([None, attn_size])  # (16, 512)

            def cond(time_step, prev_o_t, prev_softmax_input, state_c, state_h, outputs2):
                return time_step < time_steps

            def body(time_step, prev_o_t, prev_softmax_input, state_c, state_h, outputs2):
                # prev_o_t = prev: shape=(16, 512); outputs2: shape=(16, ?, 1002);
                # prev_softmax_input = init_word: shape=(16, 1002)
                state = tf.nn.rnn_cell.LSTMStateTuple(state_c, state_h)  # random initial state on the first step, then taken from the previous step
                pdb.set_trace()
                with variable_scope.variable_scope("loop_function", reuse=True):
                    inp = loop_function(prev_softmax_input, time_step)  # shape=(16, 100)
                    # What is inp for? Apparently the "bottom-up" input into each cell, while prev_o_t is the
                    # "left-to-right" input; inp is derived from the previous cell's softmax_input
                    # (the cell output just before the final softmax), i.e. prev_softmax_input.

                input_size = inp.get_shape().with_rank(2)[1]  # 100
                if input_size.value is None:
                    raise ValueError("Could not infer input size from input: %s" % inp.name)
                x = tf.concat(1, [inp, prev_o_t])  # shape=(16, 612); here inp, prev_o_t = loop_function(softmax_output), output
                # Run the RNN.
                cell_output, state = cell(x, state)  # decoder LSTM with 512 units; cell_output: shape=(16, 512), state: shape=(16, 512)
                # Run the attention mechanism.
                attn = attention(state)  # shape=(16, 512); the attention module's output, C_i

                with variable_scope.variable_scope("AttnOutputProjection"):
                    output = math_ops.tanh(linear([cell_output, attn], output_size, False))  # shape=(16, 512)  y_i = f(C_i, S_i)
                    with variable_scope.variable_scope("FinalSoftmax"):
                        softmax_input = linear(output, vocab_size, False)  # shape=(16, 1002); a softmax projection on top of the decoder

                new_outputs = tf.concat(1, [outputs2, tf.expand_dims(softmax_input, 1)])  # shape=(16, ?, 1002)  [..., y_{t-1}, y_t, ...]
                return (time_step + tf.constant(1, dtype=tf.int32),
                        output, softmax_input, state.c, state.h, new_outputs)  # both this step's outputs and the next step's inputs

            time_step = tf.constant(0, dtype=tf.int32)
            shape_invariants = [time_step.get_shape(),
                                prev.get_shape(),
                                tf.TensorShape([batch_size, vocab_size]),
                                tf.TensorShape([batch_size, 512]),
                                tf.TensorShape([batch_size, 512]),
                                tf.TensorShape([batch_size, None, vocab_size])]

            # START keyword is 0
            init_word = np.zeros([batch_size, vocab_size])  # shape=(16, 1002)

            loop_vars = [time_step,
                         prev,
                         tf.constant(init_word, dtype=tf.float32),
                         initial_state.c, initial_state.h,
                         tf.zeros([batch_size, 1, vocab_size])]

            outputs = tf.while_loop(cond, body, loop_vars, shape_invariants)  # outputs[-1]: shape=(16, ?, 1002)
            '''
            Equivalent Python semantics:
            loop_vars = [...]
            while cond(*loop_vars):
                loop_vars = body(*loop_vars)
            '''

        return outputs[-1][:, 1:], tf.nn.rnn_cell.LSTMStateTuple(outputs[-3], outputs[-2])

    def embedding_attention_decoder(initial_state,      # shape=(16, 512)
                                    attention_states,   # shape=(16, 350, 512)
                                    cell,               # the LSTM cell defined by the caller
                                    num_symbols,        # 1002
                                    time_steps,
                                    batch_size,         # 16
                                    embedding_size,     # 100
                                    output_size=None,   # 512
                                    output_projection=None,
                                    feed_previous=False,  # True here
                                    update_embedding_for_previous=True,
                                    dtype=None,
                                    scope=None):
        if output_size is None:
            output_size = cell.output_size  # 512
        if output_projection is not None:
            proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
            proj_biases.get_shape().assert_is_compatible_with([num_symbols])

        with variable_scope.variable_scope(scope or "embedding_attention_decoder", dtype=dtype) as scope:
            embedding = variable_scope.get_variable("embedding", [num_symbols, embedding_size])
            loop_function = tf.nn.seq2seq._extract_argmax_and_embed(embedding,
                              output_projection, update_embedding_for_previous) if feed_previous else None
                            # (16, 1002) ==> (16, 100): take the argmax, then embed it
            return attention_decoder(
                    initial_state,
                    attention_states,
                    cell,
                    num_symbols,    # 1002
                    time_steps,     # 8
                    batch_size,
                    output_size=output_size,  # 512
                    loop_function=loop_function)
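
To see the tf.while_loop pattern above in isolation (an outputs tensor that grows along the time axis, which is why its shape invariant has a None time dimension), here is a minimal, hypothetical sketch:

    import tensorflow as tf

    batch, vocab = 2, 5
    t0 = tf.constant(0)
    outs0 = tf.zeros([batch, 1, vocab])  # dummy first slot, stripped afterwards

    def cond(t, outs):
        return t < 3

    def body(t, outs):
        step_out = tf.fill([batch, 1, vocab], tf.to_float(t))  # fake logits for step t
        return [t + 1, tf.concat(1, [outs, step_out])]         # grow along the time axis

    # The time dimension changes between iterations, so its shape invariant
    # must be None or tf.while_loop will reject the loop body.
    t_final, outs_final = tf.while_loop(
        cond, body, [t0, outs0],
        shape_invariants=[t0.get_shape(), tf.TensorShape([batch, None, vocab])])

    with tf.Session() as sess:
        print(sess.run(tf.shape(outs_final[:, 1:])))  # [2 3 5] after dropping the dummy slot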

     

On the embedding API:

A quick test:

    #-*-coding:utf8-*-

    __author__ = "buyizhiyou"
    __date__ = "2017-11-21"

    import tensorflow as tf
    import numpy as np

    '''
    Test the embedding API
    '''
    embedding = tf.Variable(np.identity(5, dtype=np.int32))
    inputs = tf.placeholder(dtype=tf.int32, shape=[None])
    input_embedding = tf.nn.embedding_lookup(embedding, inputs)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(embedding))
        # [[1 0 0 0 0]
        #  [0 1 0 0 0]
        #  [0 0 1 0 0]
        #  [0 0 0 1 0]
        #  [0 0 0 0 1]]
        print(sess.run(input_embedding, feed_dict={inputs: [1, 2, 3, 0, 3, 2, 1]}))
        # [[0 1 0 0 0]
        #  [0 0 1 0 0]
        #  [0 0 0 1 0]
        #  [1 0 0 0 0]
        #  [0 0 0 1 0]
        #  [0 0 1 0 0]
        #  [0 1 0 0 0]]
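
This lookup is also what the loop_function built by tf.nn.seq2seq._extract_argmax_and_embed does at each decoding step; a simplified sketch (ignoring output_projection and the stop-gradient on the previous embedding) looks roughly like this:

    import tensorflow as tf

    def extract_argmax_and_embed_sketch(embedding):
        # Simplified stand-in for the helper returned by the seq2seq module.
        def loop_function(prev_softmax_input, _):
            # Pick the most likely symbol from the previous step's logits ...
            prev_symbol = tf.argmax(prev_softmax_input, 1)         # (batch,)
            # ... and feed its embedding into the next decoder step.
            return tf.nn.embedding_lookup(embedding, prev_symbol)  # (batch, embedding_size)
        return loop_function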