zoukankan      html  css  js  c++  java
  • Tensorflow RNN中的坑

    由于多个版本的积累,Tensorflow中的RNN比较杂乱,到底哪个版本实际效率高,经过实测我发现和api中说明的并不一致,在此记录一下。

    注意,以下相关代码和结论均运行在tensorflow==1.10上

    1.脉络梳理

    在1.10版本的tensorflow中,有关rnn的部分一般在以下四个包中,

    • tf.contrib.rnn
    • tf.contrib.cudnn_rnn
    • tf.nn.rnn_cell
    • tf.compat.v1.nn.rnn_cell

    其中tf.nn.rnn_cell、tf.compat.v1.nn.rnn_cell和tf.contrib.rnn互相等价,那么简化为两部分:

    • tf.contrib.rnn
    • tf.contrib.cudnn_rnn

    在此我们只考察常用的RNN cell(例如 lstm、gru等),并把等价的做个性能对比,上面包中能找到的分类如下:

    LSTM:

      tf.contrib.rnn.LSTMCell

      tf.contrib.rnn.LSTMBlockCell

      tf.contrib.rnn.LSTMBlockFusedCell

      tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell

    GRU:

      tf.contrib.rnn.GRUCell

      tf.contrib.rnn.GRUBlockCellV2

      tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell

    SRU:

      tf.contrib.rnn.SRUCell

     

    2.评测代码

    代码如下,除了简单的session.run计时外,还做了利用timeline工具profiling。但在后续分析时,主要还是利用计时结果。因为timeline的分析结果里面细化到每个最小操作符,比较麻烦。

    因为输入数据是使用numpy随机生成的,直接copy过去就能运行, 代码中比较了几种情况,分别是lstm,bi-lstm,2 layer bi-lstm。

     1 # -*- coding: utf-8 -*-
     2 import tensorflow as tf
     3 import numpy as np
     4 import time
     5 import functools
     6 import json
     7 from tensorflow.python.client import timeline
     8 
     9 
    10 def update_timeline(chrome_trace, _timeline_dict):
    11     # convert crome trace to python dict
    12     chrome_trace_dict = json.loads(chrome_trace)
    13     # for first run store full trace
    14     if _timeline_dict is None:
    15         _timeline_dict = chrome_trace_dict
    16     # for other - update only time consumption, not definitions
    17     else:
    18         for event in chrome_trace_dict['traceEvents']:
    19             # events time consumption started with 'ts' prefix
    20             if 'ts' in event:
    21                 _timeline_dict['traceEvents'].append(event)
    22     return _timeline_dict
    23 
    24 
    25 batch_size = 1
    26 time_step = 70
    27 hidden_num = 512
    28 stack_num = 1
    29 
    30 # 设置随机种子,保持每次随机数一致
    31 np.random.seed(0)
    32 # 创建正态分布输入数据,batch_size*time_step*hidden_num
    33 np_input_data = np.random.randn(batch_size, time_step, hidden_num).astype(np.float32)
    34 np_input_len = [time_step]*batch_size
    35 # print np_input_data
    36 
    37 # LSTM cells
    38 #rnn_cell = tf.contrib.rnn.LSTMCell  # child of RNNCell
    39 rnn_cell = tf.contrib.rnn.LSTMBlockCell  # child of RNNCell
    40 # rnn_cell = tf.contrib.rnn.LSTMBlockFusedCell  # not child of RNNCell
    41 # rnn_cell = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell  # child of RNNCell
    42 
    43 # GRU cells
    44 # rnn_cell = tf.contrib.rnn.GRUCell
    45 # rnn_cell = tf.contrib.rnn.GRUBlockCellV2
    46 # rnn_cell = tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell
    47 
    48 # SRU cells
    49 # rnn_cell = tf.contrib.rnn.SRUCell
    50 
    51 # 构建一个简单的双层双向lstm网络
    52 input_data = tf.placeholder(dtype=tf.float32, shape=[batch_size, time_step, hidden_num], name='input_data')
    53 trans_data = tf.transpose(input_data, [1, 0, 2])
    54 
    55 outputs = [trans_data]
    56 for i in range(stack_num):
    57     fw_rnn = rnn_cell(hidden_num,name='fw_cell_%d' % i)
    58     #bw_rnn = rnn_cell(hidden_num,name='bw_cell_%d' % i)
    59     if rnn_cell is not tf.contrib.rnn.LSTMBlockFusedCell:
    60         fw_rnn = tf.contrib.rnn.FusedRNNCellAdaptor(fw_rnn, use_dynamic_rnn=False)
    61         #bw_rnn = tf.contrib.rnn.FusedRNNCellAdaptor(bw_rnn, use_dynamic_rnn=True)
    62     #bw_rnn = tf.contrib.rnn.TimeReversedFusedRNN(bw_rnn)
    63     outputs1, state1 = fw_rnn(outputs[-1], sequence_length=np_input_len, dtype=tf.float32)
    64     #outputs2, state2 = bw_rnn(outputs[-1], sequence_length=np_input_len, dtype=tf.float32)
    65     #next_layer_input = tf.concat([outputs1, outputs2], axis=-1)
    66     #outputs.append(next_layer_input)
    67 
    68     outputs.append(outputs1)
    69 
    70 total_time = 0
    71 _timeline_dict = None
    72 runs = 1000
    73 sess = tf.InteractiveSession()
    74 sess.run(tf.global_variables_initializer())
    75 for i in range(0, runs):
    76     # options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    77     run_metadata = tf.RunMetadata()
    78     t1 = time.time()
    79     # result = sess.run([outputs[-1]], feed_dict={input_data: np_input_data}, options=options,
    80     #                   run_metadata=run_metadata)
    81     result = sess.run([outputs[-1]], feed_dict={input_data: np_input_data},
    82                       run_metadata=run_metadata)
    83     t2 = time.time()
    84     total_time += (t2 - t1)
    85     # fetched_timeline = timeline.Timeline(run_metadata.step_stats)
    86     # chrome_trace = fetched_timeline.generate_chrome_trace_format()
    87     #print t2 - t1
    88     # _timeline_dict = update_timeline(chrome_trace, _timeline_dict)
    89 
    90 print rnn_cell
    91 print 'average time %f ms' % (total_time / float(runs)*1000.0)
    92 # with open('fused_%d_runs.json' % runs, 'w') as f:
    93 #     json.dump(_timeline_dict, f)

    3.性能分析

    在V100上测试,得到的结果如下

     

     其中可以看出,无论是cudnn优化版本,还是原始版本,在gpu上的表现都不尽如人意,反倒是LSTMBlockFusedCell这个实现,效果出奇的好,甚至在GPU上好于GRU和SRU,CPU上差异也不大。至于造成这些现象的原因,需要阅读这几种实现的底层源码,因为时间精力有限,我没有仔细阅读,如果有感兴趣的读者,可以深入研究。总结如下:

    1.如果使用tf1中的rnn,建议使用LSTMBlockFusedCell

    2.所有RNN相关实现在GPU上性能,基本和层数、方向是线性关系

    3.所有RNN相关实现在CPU上性能,和层数、方向不一定是线性关系

    4.GRU和SRU的实现在CPU上比GPU上效率高

    5.tf1中的RNN坑很多,要谨慎使用

     

  • 相关阅读:
    java概述------
    java中有几种方法实现一个线程?用什么关键字修饰同步方法?stop()和suspend()方法为何不推荐使用?
    java的5个框架,哪个框架更适合你的项目?
    java的热门应用有哪些?
    vue-router 切换页面时怎么设置过渡动画
    Referrer Policy 介绍
    await进行同步操作
    vue中axios拦截器同一项目多域名如何配置
    正规方程求解特征参数的推导过程
    一种网页中显示代码所涉及的字符转义问题的解决方案
  • 原文地址:https://www.cnblogs.com/hrlnw/p/10748990.html
Copyright © 2011-2022 走看看