  • [Repost] Steps for building an LSTM, and the difference between static_rnn and dynamic_rnn

    Original article:

    https://blog.csdn.net/qq_23981335/article/details/89097757

    ---------------------
    Author: 周卫林
    Source: CSDN

    -----------------------------------------------------------------------------------------------

    1. Building the LSTM cell
    In TensorFlow 1.x there are two library functions for building an LSTM cell, tf.nn.rnn_cell.BasicLSTMCell and tf.contrib.rnn.BasicLSTMCell (aliases of the same class). The most commonly used argument is num_units, the dimensionality of the LSTM's hidden state; state_is_tuple=True (the default) represents the state as a (c, h) tuple.

     

    lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size)

    2. Initializing the hidden state
    Besides the data input, an LSTM also takes the previous time step's state as input, so the initial state has to be created explicitly:

    initial_state=lstm_cell.zero_state(batch_size,dtype=tf.float32)
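    For orientation, with state_is_tuple=True the initial state is an LSTMStateTuple holding the cell state c and the hidden state h. A minimal sketch (shapes assume the hidden_size=64 and batch_size=2 used in the snippets below):

    import tensorflow as tf

    batch_size = 2
    hidden_size = 64

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size)
    initial_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    # zero_state returns an LSTMStateTuple of two all-zero tensors.
    print(type(initial_state).__name__)  # LSTMStateTuple
    print(initial_state.c.shape)         # (2, 64)
    print(initial_state.h.shape)         # (2, 64)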

    3. Adding a dropout layer
    A dropout layer can be wrapped around the basic LSTM cell:

    lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)
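    Besides output_keep_prob, DropoutWrapper in recent TF 1.x versions also accepts input_keep_prob and state_keep_prob. A hedged sketch, where the placeholder for keep_prob is an assumption (it lets dropout be switched off at evaluation time):

    import tensorflow as tf

    hidden_size = 64
    # Assumed setup: feeding 1.0 (the default) disables dropout at eval time.
    keep_prob = tf.placeholder_with_default(1.0, shape=[])

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size)
    lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
        lstm_cell,
        input_keep_prob=keep_prob,   # dropout on the cell's inputs
        output_keep_prob=keep_prob,  # dropout on the cell's outputs
        state_keep_prob=keep_prob)   # dropout on the recurrent state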

    4. Multi-layer LSTM

    cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*hidden_layer_num)

    where hidden_layer_num is the number of LSTM layers (but note the caveat in the sketch below).
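    Be aware that [lstm_cell]*hidden_layer_num puts the same cell object in every layer; in newer TF 1.x releases this either raises an error or makes all layers share one set of weights. A safer sketch builds a fresh cell per layer:

    import tensorflow as tf

    hidden_size = 64
    hidden_layer_num = 3
    keep_prob = 0.8  # assumed value, for illustration only

    def make_cell():
        # Each layer gets its own cell and therefore its own variables.
        cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size)
        return tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)

    cell = tf.nn.rnn_cell.MultiRNNCell(
        [make_cell() for _ in range(hidden_layer_num)])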

    5. Complete code

    (1) The construction that expresses the mechanics most clearly is the explicit time-step loop:

    import tensorflow as tf
    import numpy as np

    batch_size = 2
    hidden_size = 64
    num_steps = 10
    input_dim = 8

    input_data = np.random.randn(batch_size, num_steps, input_dim)
    input_data[1, 6:] = 0  # zero the last 4 steps of the second example (simulated padding)
    x = tf.placeholder(dtype=tf.float32, shape=[batch_size, num_steps, input_dim], name='input_x')
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size)
    initial_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    state = initial_state
    outputs = []
    with tf.variable_scope('RNN'):
        for i in range(num_steps):
            if i > 0:
                # Reuse the cell's variables after the first time step.
                tf.get_variable_scope().reuse_variables()
            # The cell returns (output, new_state); the state must be
            # threaded through to the next time step.
            cell_output, state = lstm_cell(x[:, i, :], state)
            outputs.append(cell_output)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        np.set_printoptions(threshold=np.inf)

        result = sess.run(outputs, feed_dict={x: input_data})
        print(result)

    (2) A more compact construction

    If writing the for loop by hand feels tedious, tf.nn.static_rnn can be used instead; internally it is essentially the same unrolled loop. Its argument layout deserves attention:

    tf.nn.static_rnn(
        cell,
        inputs,
        initial_state=None,
        dtype=None,
        sequence_length=None,
        scope=None
    )

    Here cell is the LSTM cell; inputs must be a Python list of length num_steps whose elements each have shape [batch_size, input_dim] (i.e. time-major data, typically produced with tf.unstack, as in the code below); sequence_length is a vector with batch_size entries giving each example's true length.

    The complete code is as follows:

    import tensorflow as tf
    import numpy as np

    batch_size = 2
    num_units = 64
    num_steps = 10
    input_dim = 8

    input_data = np.random.randn(batch_size, num_steps, input_dim)
    input_data[1, 6:] = 0  # simulated padding on the second example
    x = tf.placeholder(dtype=tf.float32, shape=[batch_size, num_steps, input_dim], name='input_x')
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    initial_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
    # x: [batch_size, num_steps, input_dim], a placeholder
    # y: a list of num_steps tensors, each [batch_size, input_dim]
    y = tf.unstack(x, axis=1)
    output, state = tf.nn.static_rnn(lstm_cell, y, sequence_length=[10, 6], initial_state=initial_state)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        np.set_printoptions(threshold=np.inf)

        result1, result2 = sess.run([output, state], feed_dict={x: input_data})
        result1 = np.asarray(result1)
        result2 = np.asarray(result2)
        print(result1)
        print('*' * 100)
        print(result2)
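    With sequence_length=[10, 6], static_rnn emits zero vectors for the second example after step 6 and simply copies its state through. Continuing the script above (result1 has shape [num_steps, batch_size, num_units]), a quick sanity check:

    # Expected True: steps past the second example's length are zeroed.
    print(np.all(result1[6:, 1, :] == 0))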

    The same thing can also be done with tf.nn.dynamic_rnn:

    tf.nn.dynamic_rnn(
        cell,
        inputs,
        sequence_length=None,
        initial_state=None,
        dtype=None,
        parallel_iterations=None,
        swap_memory=False,
        time_major=False,
        scope=None
    )

    Here cell is again the LSTM cell, and inputs is a single tensor of shape [batch_size, num_steps, input_dim] (with the default time_major=False):

    output, state = tf.nn.dynamic_rnn(cell, x, sequence_length=[10, 6], initial_state=initial_state)
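    Unlike static_rnn, dynamic_rnn returns a single batch-major tensor rather than a list, and it likewise zeroes the steps past each example's sequence_length. A minimal self-contained sketch (constants chosen to match the examples above):

    import numpy as np
    import tensorflow as tf

    batch_size, num_steps, input_dim, num_units = 2, 10, 8, 64

    x = tf.placeholder(tf.float32, [batch_size, num_steps, input_dim])
    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    initial_state = cell.zero_state(batch_size, dtype=tf.float32)

    # output: [batch_size, num_steps, num_units]; state: final LSTMStateTuple.
    output, state = tf.nn.dynamic_rnn(cell, x, sequence_length=[10, 6],
                                      initial_state=initial_state)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(output, {x: np.random.randn(batch_size, num_steps, input_dim)})
        print(out.shape)                   # (2, 10, 64)
        print(np.all(out[1, 6:, :] == 0))  # True: padded steps are zeroed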

    6. Differences between static_rnn and dynamic_rnn
            With both static_rnn and dynamic_rnn, every sequence within a single batch must have the same length (shorter sequences must be padded by the caller). The differences are that dynamic_rnn stops computing at each example's sequence_length, and that it builds the time loop dynamically in the graph (via tf.while_loop) rather than unrolling it.
    Furthermore, with dynamic_rnn different batches may have different sequence lengths, e.g. a first batch of length 10 and a second of length 20, whereas with static_rnn every batch must have exactly num_steps steps.

            The following code uses dynamic_rnn to feed batches with different sequence lengths:

    import tensorflow as tf
    import numpy as np

    batch_size = 2
    num_units = 64
    num_steps = 10
    input_dim = 8

    input_data = np.random.randn(batch_size, num_steps, input_dim)
    input_data2 = np.random.randn(batch_size, num_steps * 2, input_dim)

    # None means the number of time steps is not fixed at graph-construction time.
    x = tf.placeholder(dtype=tf.float32, shape=[batch_size, None, input_dim], name='input')
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    initial_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    output, state = tf.nn.dynamic_rnn(lstm_cell, x, initial_state=initial_state)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        np.set_printoptions(threshold=np.inf)

        # Sequence length 10: x is [batch_size, 10, input_dim], so the cell runs 10 times.
        result1, result2 = sess.run([output, state], feed_dict={x: input_data})
        result1 = np.asarray(result1)
        result2 = np.asarray(result2)
        print(result1)
        print('*' * 100)
        print(result2)

        # Sequence length 20: x is [batch_size, 20, input_dim], so the cell runs 20 times.
        result1, result2 = sess.run([output, state], feed_dict={x: input_data2})
        result1 = np.asarray(result1)
        result2 = np.asarray(result2)
        print(result1)
        print('*' * 100)
        print(result2)

    static_rnn, however, cannot do this.
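    The reason is structural: static_rnn unrolls the graph at construction time, one cell copy per step, so the number of steps must be known statically. A sketch of the failure mode (the exact error message may vary by version):

    import tensorflow as tf

    batch_size, input_dim = 2, 8

    # The time dimension is None, i.e. unknown when the graph is built.
    x = tf.placeholder(tf.float32, [batch_size, None, input_dim])

    # static_rnn needs one tensor per step, but unstacking along an
    # unknown axis is impossible, so this raises a ValueError like
    # "Cannot infer num from shape (2, ?, 8)".
    y = tf.unstack(x, axis=1)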

    7. Performance difference between dynamic_rnn and static_rnn

    import tensorflow as tf
    import numpy as np
    import time

    num_step = 100
    input_dim = 8
    batch_size = 2
    num_unit = 64

    input_data = np.random.randn(batch_size, num_step, input_dim)
    x = tf.placeholder(dtype=tf.float32, shape=[batch_size, num_step, input_dim])
    seq_len = tf.placeholder(dtype=tf.int32, shape=[batch_size])
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_unit)
    initial_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    y = tf.unstack(x, axis=1)
    output1, state1 = tf.nn.static_rnn(lstm_cell, y, sequence_length=seq_len, initial_state=initial_state)

    output2, state2 = tf.nn.dynamic_rnn(lstm_cell, x, sequence_length=seq_len, initial_state=initial_state)

    print('begin train...')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Warm-up runs so one-time setup cost does not pollute the timings.
        for i in range(100):
            sess.run([output1, state1], feed_dict={x: input_data, seq_len: [10] * batch_size})

        time1 = time.time()
        for i in range(100):
            sess.run([output1, state1], feed_dict={x: input_data, seq_len: [10] * batch_size})
        time2 = time.time()
        print('static_rnn seq_len:10\t\t{}'.format(time2 - time1))

        for i in range(100):
            sess.run([output1, state1], feed_dict={x: input_data, seq_len: [100] * batch_size})
        time3 = time.time()
        print('static_rnn seq_len:100\t\t{}'.format(time3 - time2))

        for i in range(100):
            sess.run([output2, state2], feed_dict={x: input_data, seq_len: [10] * batch_size})
        time4 = time.time()
        print('dynamic_rnn seq_len:10\t\t{}'.format(time4 - time3))

        for i in range(100):
            sess.run([output2, state2], feed_dict={x: input_data, seq_len: [100] * batch_size})
        time5 = time.time()
        print('dynamic_rnn seq_len:100\t\t{}'.format(time5 - time4))

    result:

    static_rnn seq_len:10       0.8497538566589355
    static_rnn seq_len:100      1.5897266864776611
    dynamic_rnn seq_len:10      0.4857025146484375
    dynamic_rnn seq_len:100     2.8693313598632812

    Shorter sequences run faster than longer ones. At seq_len=10, dynamic_rnn is faster than static_rnn because it stops computing once each example's sequence_length is reached, whereas static_rnn always steps through all num_steps. At seq_len=100 the ranking reverses, presumably because the per-step overhead of dynamic_rnn's tf.while_loop outweighs simply executing the 100 statically unrolled cells.

    -----------------------------------------------------------------------------------------------

  • Reposted from: https://www.cnblogs.com/devilmaycry812839668/p/11109399.html