zoukankan      html  css  js  c++  java
  • tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别

    tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别

    https://blog.csdn.net/u014365862/article/details/78238807

    MachineLP的Github(欢迎follow):https://github.com/MachineLP

    我的GitHub:https://github.com/MachineLP/train_cnn-rnn-attention 自己搭建的一个框架,包含模型有:vgg(vgg16,vgg19), resnet(resnet_v2_50,resnet_v2_101,resnet_v2_152), inception_v4, inception_resnet_v2等。

    # Hyper-parameters used by the RNN examples below.
    chunk_size = 256  # features per time step (last dim of the reshaped input)
    chunk_n = 160  # number of time steps per sample
    rnn_size = 256  # hidden units in each LSTM cell
    num_layers = 2  # number of stacked LSTM layers (multi-layer example)
    # Width of the output layer. MAX_CAPTCHA and CHAR_SET_LEN are defined
    # elsewhere; presumably captcha length x character-set size — confirm.
    n_output_layer = MAX_CAPTCHA * CHAR_SET_LEN

    单层rnn:

    tf.contrib.rnn.static_rnn:

    输入:长度为步长(n_steps)的列表,每个元素是形状为[batch,input]的张量(可以把[步长,batch,input]的张量沿第0维切分得到)

    输出:长度为n_steps的列表,每个元素形状为[batch,n_hidden],outputs[-1]就是最后一步的输出

    另外,可以用tf.contrib.rnn.DropoutWrapper包装cell,在rnn中加入dropout(见下面的多层rnn示例)。

    def recurrent_neural_network(data):
        """Single-layer LSTM built with tf.contrib.rnn.static_rnn.

        static_rnn takes its input as a Python list with one tensor per time
        step. `data` is therefore reshaped to [batch, chunk_n, chunk_size],
        made time-major, and split into `chunk_n` tensors of shape
        [batch, chunk_size].

        Args:
            data: input tensor whose elements can be reshaped to
                [-1, chunk_n, chunk_size].

        Returns:
            Unactivated output of the final dense layer, shape
            [batch, n_output_layer], computed from the last time step.
        """
        # [batch, chunk_n, chunk_size]
        data = tf.reshape(data, [-1, chunk_n, chunk_size])
        # -> time-major [chunk_n, batch, chunk_size]
        data = tf.transpose(data, [1, 0, 2])
        # -> [chunk_n * batch, chunk_size], then split along axis 0 into a
        #    list of chunk_n tensors, each [batch, chunk_size].
        data = tf.reshape(data, [-1, chunk_size])
        data = tf.split(data, chunk_n)

        # Plain RNN only: one LSTM layer followed by a dense output layer.
        layer = {
            'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),
            'b_': tf.Variable(tf.random_normal([n_output_layer])),
        }
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(rnn_size)
        # static_rnn returns a list with one output per time step; the final
        # state is not needed here.
        outputs, _ = tf.contrib.rnn.static_rnn(lstm_cell, data, dtype=tf.float32)
        output = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])
        return output

    多层rnn:

    tf.nn.dynamic_rnn:

    输入:[batch,步长,input] 
    输出:[batch,n_steps,n_hidden] 
    所以要取最后一步的output,可以先tf.transpose(outputs, [1, 0, 2])再取outputs[-1],或者直接取outputs[:, -1, :],两者等价。

    def recurrent_neural_network(data, keep_prob=None):
        """Stacked (multi-layer) LSTM built with tf.nn.dynamic_rnn.

        Unlike static_rnn, dynamic_rnn takes a single batch-major tensor of
        shape [batch, chunk_n, chunk_size] and returns outputs of shape
        [batch, chunk_n, rnn_size], so no transpose/split is needed on input.

        Args:
            data: input tensor whose elements can be reshaped to
                [-1, chunk_n, chunk_size].
            keep_prob: optional output keep probability (float or scalar
                tensor). When given, every LSTM layer is wrapped in a
                DropoutWrapper; when None (default), no dropout is applied.

        Returns:
            Unactivated output of the final dense layer, shape
            [batch, n_output_layer], computed from the last time step.
        """
        # [batch, chunk_n, chunk_size]
        data = tf.reshape(data, [-1, chunk_n, chunk_size])

        layer = {
            'w_': tf.Variable(tf.random_normal([rnn_size, n_output_layer])),
            'b_': tf.Variable(tf.random_normal([n_output_layer])),
        }

        def lstm_cell():
            # One LSTM layer, optionally with dropout on its outputs.
            cell = tf.contrib.rnn.LSTMCell(rnn_size)
            if keep_prob is None:
                return cell
            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)

        # Create a distinct cell object for each layer
        # (not [lstm_cell()] * num_layers, which would reuse one object).
        stack = tf.contrib.rnn.MultiRNNCell(
            [lstm_cell() for _ in range(num_layers)], state_is_tuple=True)
        # outputs: [batch, chunk_n, rnn_size]; the final state is unused.
        outputs, _ = tf.nn.dynamic_rnn(stack, data, dtype=tf.float32)
        # Output of the last time step, [batch, rnn_size]. Equivalent to
        # tf.transpose(outputs, (1, 0, 2))[-1] without the full transpose.
        last_output = outputs[:, -1, :]
        output = tf.add(tf.matmul(last_output, layer['w_']), layer['b_'])
        return output




  • 相关阅读:
    HRESULT:0x80070057 (E_INVALIDARG)的异常的解决方案
    c# 取两个时间的间隔
    [转]C#算法
    智能仓库管理系统方案(四)
    分页存储过程
    ASP.NET2.0_多语言本地化应用程序(转)
    C#绘图双缓冲技术总结(转)
    C#.net同步异步SOCKET通讯和多线程总结(转)
    WIN2003 sp2中Delphi 7中的Project菜单中Options菜单打不开
    C#关于日期月天数和一年有多少周及某年某周时间段的计算
  • 原文地址:https://www.cnblogs.com/DjangoBlog/p/9542571.html
Copyright © 2011-2022 走看看