zoukankan      html  css  js  c++  java
  • TensorFlow-LSTM序列预测

      问题情境:已知某一天内到目前为止股票各个时刻的价格,预测接下来短时间内的价格变化。

    import tushare as ts
    import time
    from collections import namedtuple
    import numpy as np
    import tensorflow as tf
    
    
    class TPPM:
        def __init__(self):
            self.input_size = 10
            self.output_size = 10
            self.lstm_size = 10
            self.learning_rate = 0.001
            self.time_step = 1
    
            self.input = tf.placeholder(tf.float32, shape=(None, self.time_step, self.input_size), name='input')
            self.label = tf.placeholder(tf.float32, shape=(None, self.time_step, self.output_size), name='label')
    
            self.lstm = tf.nn.rnn_cell.BasicLSTMCell(self.lstm_size)
            self.initial_state = self.lstm.zero_state(1, tf.float32)
            self.output, self.final_state = tf.nn.dynamic_rnn(self.lstm, self.input, initial_state=self.initial_state, dtype=tf.float32)
    
            self.loss = tf.reduce_mean(tf.square(tf.reshape(self.output, [-1]) - tf.reshape(self.label, [-1])))
    
            self.optimizer = tf.train.AdamOptimizer(1e-4).minimize(self.loss)
    
        def getData(self):
            d = ts.get_tick_data('601318', date='2017-10-16')
            sequence = np.array([row.price - 57.0 for row in d.itertuples()], dtype=np.float32)
            return sequence
        def train(self):
            sequence = self.getData()
            print(sequence)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
    
                for i in range(500):
                    for j in range(0,4500,10):
                        feed_input=sequence[j:j+10].reshape(-1,1,10)
                        feed_label=sequence[j+10:j+20].reshape(-1,1,10)
                        _, loss_ = sess.run([self.optimizer, self.loss], feed_dict={self.input: feed_input, self.label: feed_label})
                    print(i,',',loss_)
                file = open('data/predict.csv', 'w')
                with file:
                    for j in range(0, 4500, 10):
                        feed_input = sequence[j:j + 10].reshape(-1, 1, 10)
                        feed_label = sequence[j + 10:j + 20].reshape(-1, 1, 10)
                        predict = sess.run([self.output], feed_dict={self.input: feed_input, self.label: feed_label})
                        for i in range(10):
                            file.write("%d,%.4f,%.4f
    " % (j + i, np.array(feed_label).reshape(-1)[i], np.array(predict).reshape(-1)[i]))
    
        def restore(self):
            pass
        def save(self):
            pass
    
    model=TPPM()
    model.train()

      运行结果:

      结构比较简单,训练次数也不多,可以看到结果还是比较令人失望的,不过勉强好像是有那么点意思。

      先说__init__:先定义了几个超参数,比较好理解。接下来是输入和标签的占位符。

    self.lstm = tf.nn.rnn_cell.BasicLSTMCell(self.lstm_size)
    self.initial_state = self.lstm.zero_state(1, tf.float32)
    self.output, self.final_state = tf.nn.dynamic_rnn(self.lstm, self.input, initial_state=self.initial_state, dtype=tf.float32)

      以上代码定义LSTM层,这里要注意两点:

      1)第三行的dynamic_rnn里的第二个参数,至少是3维的。第一维是batch_size,表示一个批次处理多少组数据。第二维是time_step,表示序列长度,因为做的是序列预测嘛,每次的输入不是上一个点,而是上一段时间,time_step就表示取得这一段时间内有多少个点。第三及之后的维度是具体描述这个点是什么状况,看具体问题。我的代码里是把相邻的10个时刻里的价格作为一个点,time_step为1,表示我每次输入的序列只由一个点组成,但这个维度还是要有的,batch_size根据feed的情况具体计算,我为了简单,batch_size这里其实也是1

      2)用这个BasicLSTMCell的时候,lstm_size要和input的shape对应起来,具体是什么关系现在还不是很清楚,我这里因为描述部分只由一维所以只要lstm_size=input_size就可以了。

      getData:使用tushare获取股票数据。

      train就是训练,注意feed的数据和模型定义时的数据shape对应上就可以了。

      现在有两个比较大的问题,对信心比较有影响:

      1)具体效果好不好其实现在也不是完全确定,因为之前做天池的口碑商家那个比赛的时候直接就输出近期内的平均值就能得到一个不错的结果。相比于这里来说,我倒看不出有什么明显的优势,单从截图部分来看,还不如平均数呢。

      2)既然循环神经网络是历史输入会对之后造成影响,那么每次把全部数据训练完成后重新回到数据的开始部分进行下一轮训练时,相当于数据来了个突变,这会不会产生什么不好的影响呢。

    ------------------------------------------------------------

      如果用这个程序来指导投机行为,结果会如何呢?我试了一下,假设我初始有1000元,每次预测完接下来的10个时刻的价格后取其中最高价与输入的最后一个时刻价格比较,如果高就进行一次买入卖出行为:

    cnt=1000.0
    with file:
        for j in range(0, 4500, 10):
            feed_input = sequence[j:j + 10].reshape(-1, 1, 10)
            feed_label = sequence[j + 10:j + 20].reshape(-1, 1, 10)
            predict = sess.run([self.output], feed_dict={self.input: feed_input, self.label: feed_label})
            max=0
            for i in range(10):
                if (np.array(predict).reshape(-1)[i]>np.array(predict).reshape(-1)[max]):
                    max=i
                file.write("%d,%.4f,%.4f
    " % (j + i, np.array(feed_label).reshape(-1)[i], np.array(predict).reshape(-1)[i]))
            if (np.array(predict).reshape(-1)[max]>np.array(feed_input).reshape(-1)[9]):
                cnt=cnt/(np.array(feed_input).reshape(-1)[9]+57.0)*(np.array(feed_label).reshape(-1)[max]+57.0)
    print(cnt)

     模拟了一天的操作后,赚了3块钱:

  • 相关阅读:
    使用PHPMYADMIN添加新用户和数据库
    phpMyadmin用户权限中英对照
    Asp生成xml乱码解放方法
    SQL Server 错误日志
    安装VS2005 SP1时失败(错误 1718。文件被数字签名策略拒绝)的解决办法!
    CKEditor 3 JavaScript API Documentation
    CKEditor在.NET中的应用
    IIS 添加网站显示错误消息 “无更多可用的内存以更新安全信息” 解决方法
    jQuery 操作Cookie
    JavaScript/HTML格式化
  • 原文地址:https://www.cnblogs.com/dramstadt/p/7694320.html
Copyright © 2011-2022 走看看