zoukankan      html  css  js  c++  java
  • tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别

    tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别

    https://blog.csdn.net/u014365862/article/details/78238807

    MachineLP的Github(欢迎follow):https://github.com/MachineLP

    我的GitHub:https://github.com/MachineLP/train_cnn-rnn-attention 自己搭建的一个框架,包含模型有:vgg(vgg16,vgg19), resnet(resnet_v2_50,resnet_v2_101,resnet_v2_152), inception_v4, inception_resnet_v2等。

    # Hyper-parameters shared by both RNN examples below.
    chunk_size = 256  # features fed to the RNN at each time step
    chunk_n = 160     # number of time steps (each sample holds chunk_n * chunk_size values)
    rnn_size = 256    # hidden units per LSTM cell
    num_layers = 2    # number of stacked LSTM layers in the multi-layer example
    # Output-layer width. MAX_CAPTCHA and CHAR_SET_LEN are defined elsewhere (not
    # shown); presumably captcha length x character-set size -- confirm.
    n_output_layer = MAX_CAPTCHA * CHAR_SET_LEN  # output layer

    单层rnn:

    tf.contrib.rnn.static_rnn:

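    输入:长度为步长(n_steps)的 Python 列表,每个元素的形状为 [batch, input]

    输出:长度为 n_steps 的列表,每个元素的形状为 [batch, n_hidden],因此 outputs[-1] 就是最后一步的输出

    static_rnn 在建图时按固定步长把 RNN 展开。在 RNN 中加入 dropout 的写法见下面多层 RNN 示例中的 tf.contrib.rnn.DropoutWrapper。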

    def recurrent_neural_network(data):
        """Single-layer LSTM classifier built with ``tf.contrib.rnn.static_rnn``.

        Args:
            data: tensor that can be reshaped to ``[batch, chunk_n, chunk_size]``.

        Returns:
            Logits of shape ``[batch, n_output_layer]``, computed from the LSTM
            output at the last time step.
        """
        # static_rnn expects a Python list of chunk_n tensors, each [batch, chunk_size].
        # Unstacking along the time axis (axis=1) yields exactly that list and is
        # equivalent to transpose([1, 0, 2]) -> reshape([-1, chunk_size]) -> split(chunk_n).
        data = tf.reshape(data, [-1, chunk_n, chunk_size])
        inputs = tf.unstack(data, num=chunk_n, axis=1)

        # Plain LSTM (no stacking); output projection rnn_size -> n_output_layer.
        layer = {
            'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),
            'b_': tf.Variable(tf.random_normal([n_output_layer])),
        }
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(rnn_size)
        # outputs: list of chunk_n tensors, each [batch, rnn_size]; the final state is unused.
        outputs, _ = tf.contrib.rnn.static_rnn(lstm_cell, inputs, dtype=tf.float32)

        # outputs[-1] is the output of the last time step.
        output = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])
        return output

    多层rnn:

    tf.nn.dynamic_rnn:

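    输入:[batch,步长,input](默认 time_major=False)
    输出:[batch,n_steps,n_hidden]
    与 static_rnn 不同,dynamic_rnn 直接接收三维张量,运行时用循环逐步计算,不需要把输入拆成列表。要取最后一步的 output,可以先用 tf.transpose(outputs, [1, 0, 2]) 把输出变成 [n_steps,batch,n_hidden],再取 outputs[-1](等价于 outputs[:, -1, :])。注意:单层/多层与选用 static_rnn 还是 dynamic_rnn 无关,两者都可以配合 MultiRNNCell 使用。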

    def recurrent_neural_network(data, keep_prob=None):
        """Stacked LSTM classifier (num_layers layers) built with ``tf.nn.dynamic_rnn``.

        Args:
            data: tensor that can be reshaped to ``[batch, chunk_n, chunk_size]``.
            keep_prob: optional output keep probability (Python float or scalar
                tensor) for dropout between layers. ``None`` (the default)
                applies no dropout; otherwise every LSTM layer is wrapped in
                ``tf.contrib.rnn.DropoutWrapper``.

        Returns:
            Logits of shape ``[batch, n_output_layer]``, computed from the
            top-layer output at the last time step.
        """
        # dynamic_rnn (time_major=False) takes one [batch, chunk_n, chunk_size] tensor,
        # so no transpose/split into a list is needed here.
        data = tf.reshape(data, [-1, chunk_n, chunk_size])

        # Output projection: rnn_size -> n_output_layer.
        layer = {
            'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),
            'b_': tf.Variable(tf.random_normal([n_output_layer])),
        }

        def _make_cell():
            # Build one LSTM layer, optionally wrapped with output dropout.
            cell = tf.contrib.rnn.LSTMCell(rnn_size)
            if keep_prob is not None:
                cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
            return cell

        # Each layer needs its own cell instance, hence a fresh cell per iteration
        # rather than [cell] * num_layers.
        stack = tf.contrib.rnn.MultiRNNCell(
            [_make_cell() for _ in range(num_layers)], state_is_tuple=True)
        outputs, _ = tf.nn.dynamic_rnn(stack, data, dtype=tf.float32)

        # [batch, chunk_n, rnn_size] -> [chunk_n, batch, rnn_size], so outputs[-1]
        # is the last time step (equivalent to outputs[:, -1, :] before transposing).
        outputs = tf.transpose(outputs, (1, 0, 2))

        output = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])
        return output




  • 相关阅读:
    Java垃圾回收
    Android Starting Window(Preview Window)
    JVM虚拟机结构
    表驱动法 -《代码大全》读书笔记
    快速Android开发系列网络篇之Retrofit
    快速Android开发系列网络篇之Volley
    快速Android开发系列网络篇之Android-Async-Http
    清除Android工程中没用到的资源
    快速Android开发系列通信篇之EventBus
    Android点击列表后弹出输入框,所点击项自动滚动到输入框上方
  • 原文地址:https://www.cnblogs.com/DjangoBlog/p/9542571.html
Copyright © 2011-2022 走看看