zoukankan      html  css  js  c++  java
  • 两层LSTM的使用

    一层的lstm效果不是很好,使用两层的lstm,代码如下。

     1 with graph.as_default():
     2     inputs_ = tf.placeholder(tf.int32, [None, seq_len], name='inputs')
     3     labels_ = tf.placeholder(tf.int32, [None, 4], name='labels')
     4     keep_prob = tf.placeholder(tf.float32, name='keep_prob')
     5 
     6     embedding = tf.Variable(tf.random_uniform((n_words + 1, embed_size), -1, 1))
     7     embed = tf.nn.embedding_lookup(embedding, inputs_)
     8     ################################# R N N #################################
     9     def LSTM_with_drop(lstm_size, keep_prob):
    10         # Your basic LSTM cell
    11         lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size, state_is_tuple=True)
    12 
    13         # Add dropout to the cell
    14         drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)
    15         return drop
    16     
    17     fw_cell = tf.nn.rnn_cell.MultiRNNCell( [LSTM_with_drop(lstm_size, keep_prob) for _ in range(lstm_layers)] )
    18     bw_cell = tf.nn.rnn_cell.MultiRNNCell( [LSTM_with_drop(lstm_size, keep_prob) for _ in range(lstm_layers)] )
    19     # Getting an initial state of all zeros
    20     fw_initial_state = fw_cell.zero_state(batch_size, tf.float32)
    21     bw_initial_state = bw_cell.zero_state(batch_size, tf.float32)
    22     
    23     outputs, final_state = tf.nn.bidirectional_dynamic_rnn(inputs=embed, 
    24                                                            cell_fw=fw_cell, 
    25                                                            cell_bw=fw_cell, 
    26                                                            initial_state_fw=fw_initial_state, 
    27                                                            initial_state_bw=bw_initial_state)
    28     state = tf.concat([outputs[0][:,-1], outputs[1][:,-1]], 1)
    29     ################################# R N N #################################
  • 相关阅读:
    python-web 创建一个输入链接生成的网站
    查看端口有没被占用
    bs的过滤器功能例子
    爬图片的方法
    python 下载图片的方法
    request 里面参数设置 (有空瞄下)
    python 面向对象 初始化(类变量 和 函数内变量)
    访问https请求出现警告,去掉警告的方法
    find 和 find_all 用法
    D3的基本设计思路
  • 原文地址:https://www.cnblogs.com/demo-deng/p/10195812.html
Copyright © 2011-2022 走看看