zoukankan      html  css  js  c++  java
  • Tensorflow RNN_LSTM实例

    RNN的一种类型模型被称为长短期记忆网络(LSTM)。我觉得这是一个有趣的名字。它听起来也意味着:短期模式长期不会被遗忘。

    LSTM的精确实现细节不在本文的范围之内。相信我,如果只学习LSTM模型会分散我们的注意力,因为它还没有确定的标准

     

    我们现在开始我们的教程。首先从编写我们的代码开始,先创建一个新的文件,叫做simple_regression.py。导入相关的库,如步骤1所示。

    步骤1:导入相关库

    import numpy as np

    import tensorflow as tf

    from tensorflow.contrib import rnn

    接着,定义一个类叫做SeriesPredictor。如步骤2所示,构造函数里面设置模型超参数,权重和成本函数。

    步骤2:定义一个类及其构造函数

    class SeriesPredictor:

    def __init__(self, input_dim, seq_size, hidden_dim=10):

    self.input_dim = input_dim //#A

    self.seq_size = seq_size //#A

    self.hidden_dim = hidden_dim //#A

    self.W_out = tf.Variable(tf.random_normal([hidden_dim, 1]),name='W_out') //#B

    self.b_out = tf.Variable(tf.random_normal([1]), name='b_out') //#B

    self.x = tf.placeholder(tf.float32, [None, seq_size, input_dim]) //#B

    self.y = tf.placeholder(tf.float32, [None, seq_size]) //#B

    self.cost = tf.reduce_mean(tf.square(self.model() - self.y)) //#C

    self.train_op = tf.train.AdamOptimizer().minimize(self.cost) //#C

    self.saver = tf.train.Saver() //#D

    #A超参数。

    #B权重变量和输入占位符。

    #C成本优化器(cost optimizer)。

    #D辅助操作

    接下来,我们使用TensorFlow的内置RNN模型,名为BasicLSTMCellLSTM单元的隐藏维度是通过时间的隐藏状态的维度。我们可以使用该rnn.dynamic_rnn函数处理这个单元格数据,以检索输出结果。步骤3详细介绍了如何使用TensorFlow来实现使用LSTM的预测模型。

    步骤3:定义RNN模型

    def model(self):

    """

    :param x: inputs of size [T, batch_size, input_size]

    :param W: matrix of fully-connected output layer weights

    :param b: vector of fully-connected output layer biases

    """

    cell = rnn.BasicLSTMCell(self.hidden_dim) #A

    outputs, states = tf.nn.dynamic_rnn(cell, self.x, dtype=tf.float32) #B

    num_examples = tf.shape(self.x)[0]

    W_repeated = tf.tile(tf.expand_dims(self.W_out, 0), [num_examples, 1, 1])#C

    out = tf.matmul(outputs, W_repeated) + self.b_out

    out = tf.squeeze(out)

    return out

    #A创建一个LSTM单元。

    #B运行输入单元,获取输出和状态的张量。

    #C将输出层计算为完全连接的线性函数。

    通过定义模型和成本函数,我们现在可以实现训练函数,该函数学习给定示例输入/输出对的LSTM权重。如步骤4所示,你打开会话并重复运行优化器。

    另外,你可以使用交叉验证来确定训练模型的迭代次数。在这里我们假设固定数量的epocs

    训练后,将模型保存到文件中,以便稍后加载使用。

    步骤4:在一个数据集上训练模型

    def train(self, train_x, train_y):

    with tf.Session() as sess:

    tf.get_variable_scope().reuse_variables()

    sess.run(tf.global_variables_initializer())

    for i in range(1000): #A

    mse = sess.run([self.train_op, self.cost], feed_dict={self.x: train_x, self.y: train_y})

    if i % 100 == 0:

    print(i, mse)

    save_path = self.saver.save(sess, 'model.ckpt')

    print('Model saved to {}'.format(save_path))

    #A训练1000

    我们的模型已经成功地学习了参数。接下来,我们想评估利用其他数据来评估以下预测模型的性能。步骤5加载已保存的模型,并通过馈送一些测试数据以此来运行模型。如果学习的模型在测试数据上表现不佳,那么我们可以尝试调整LSTM单元格的隐藏维数

    步骤5:测试学习的模型

    def test(self, test_x):

    with tf.Session() as sess:

    tf.get_variable_scope().reuse_variables()

    self.saver.restore(sess, './model.ckpt')

    output = sess.run(self.model(), feed_dict={self.x: test_x})

    print(output)

    但为了完善自己的工作,让我们组成一些数据,并尝试训练预测模型。在步骤6中,我们将创建输入序列,称为train_x,和相应的输出序列,称为train_y

    步骤6训练并测试一些虚拟数据

    if __name__ == '__main__':

    predictor = SeriesPredictor(input_dim=1, seq_size=4, hidden_dim=10)

    train_x = [[[1], [2], [5], [6]],

    [[5], [7], [7], [8]],

    [[3], [4], [5], [7]]]

    train_y = [[1, 3, 7, 11],

    [5, 12, 14, 15],

    [3, 7, 9, 12]]

    predictor.train(train_x, train_y)

    test_x = [[[1], [2], [3], [4]], #A

    [[4], [5], [6], [7]]] #B

    predictor.test(test_x)

    #A预测结果应为1357

    #B预测结果应为491113

    你可以将此预测模型视为黑盒子,并用现实世界的时间数据进行测试。

  • 相关阅读:
    ★★★5230打字慢的解决方法...绝对有用...只需要在手机上轻微的设置一下(转)
    IT公司中最流行的10种编程语言(转)
    Windows下安装Object C开发环境,及Hello Word(转)
    [图]AMD的CPU在VirtualBox中安装Mac OS X 10.6(转)
    How_to_Handle_Pointer_Events_in_a_Custom_Control(转)
    Cannot obtain license for Compiler (feature compiler) with license version >= 2.2(转)
    GNUstep Gorm第一个视窗程序,第一个图形界面,第一个helloworld gui(转)
    Symbian源码分析(转)
    Symbian计算MD5(转)
    Does not support program for platform "WINSCW"
  • 原文地址:https://www.cnblogs.com/smuxiaolei/p/8647207.html
Copyright © 2011-2022 走看看