zoukankan      html  css  js  c++  java
  • Tensorflow2(预课程)---11.2、循环神经网络实现股票预测(LSTM)

    Tensorflow2(预课程)---11.2、循环神经网络实现股票预测(LSTM)

    一、总结

    一句话总结:

    将SimpleRNN层换成LSTM层即可,非常简单
    # 测试集变array并reshape为符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
    x_test, y_test = np.array(x_test), np.array(y_test)
    x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))
    
    model = tf.keras.Sequential([
        LSTM(80, return_sequences=True),
        Dropout(0.2),
        LSTM(100),
        Dropout(0.2),
        Dense(1)
    ])

    二、循环神经网络实现股票预测(LSTM)

    博客对应课程的视频位置:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Dropout, Dense, LSTM
    import matplotlib.pyplot as plt
    import os
    import pandas as pd
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.metrics import mean_squared_error, mean_absolute_error
    import math
    
    maotai = pd.read_csv('./SH600519.csv')  # 读取股票文件
    
    training_set = maotai.iloc[0:2426 - 300, 2:3].values  # 前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价
    test_set = maotai.iloc[2426 - 300:, 2:3].values  # 后300天的开盘价作为测试集
    
    # 归一化
    sc = MinMaxScaler(feature_range=(0, 1))  # 定义归一化:归一化到(0,1)之间
    training_set_scaled = sc.fit_transform(training_set)  # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化
    test_set = sc.transform(test_set)  # 利用训练集的属性对测试集进行归一化
    
    x_train = []
    y_train = []
    
    x_test = []
    y_test = []
    
    # 测试集:csv表格中前2426-300=2126天数据
    # 利用for循环,遍历整个训练集,提取训练集中连续60天的开盘价作为输入特征x_train,第61天的数据作为标签,for循环共构建2426-300-60=2066组数据。
    for i in range(60, len(training_set_scaled)):
        x_train.append(training_set_scaled[i - 60:i, 0])
        y_train.append(training_set_scaled[i, 0])
    # 对训练集进行打乱
    np.random.seed(7)
    np.random.shuffle(x_train)
    np.random.seed(7)
    np.random.shuffle(y_train)
    tf.random.set_seed(7)
    # 将训练集由list格式变为array格式
    x_train, y_train = np.array(x_train), np.array(y_train)
    
    # 使x_train符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。
    # 此处整个数据集送入,送入样本数为x_train.shape[0]即2066组数据;输入60个开盘价,预测出第61天的开盘价,循环核时间展开步数为60; 每个时间步送入的特征是某一天的开盘价,只有1个数据,故每个时间步输入特征个数为1
    x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
    # 测试集:csv表格中后300天数据
    # 利用for循环,遍历整个测试集,提取测试集中连续60天的开盘价作为输入特征x_train,第61天的数据作为标签,for循环共构建300-60=240组数据。
    for i in range(60, len(test_set)):
        x_test.append(test_set[i - 60:i, 0])
        y_test.append(test_set[i, 0])
    # 测试集变array并reshape为符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
    x_test, y_test = np.array(x_test), np.array(y_test)
    x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))
    
    model = tf.keras.Sequential([
        LSTM(80, return_sequences=True),
        Dropout(0.2),
        LSTM(100),
        Dropout(0.2),
        Dense(1)
    ])
    
    model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                  loss='mean_squared_error')  # 损失函数用均方误差
    # 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
    
    checkpoint_save_path = "./checkpoint/LSTM_stock.ckpt"
    
    if os.path.exists(checkpoint_save_path + '.index'):
        print('-------------load the model-----------------')
        model.load_weights(checkpoint_save_path)
    
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                     save_weights_only=True,
                                                     save_best_only=True,
                                                     monitor='val_loss')
    
    history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1,
                        callbacks=[cp_callback])
    
    model.summary()
    
    file = open('./weights.txt', 'w')  # 参数提取
    for v in model.trainable_variables:
        file.write(str(v.name) + '
    ')
        file.write(str(v.shape) + '
    ')
        file.write(str(v.numpy()) + '
    ')
    file.close()
    
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()
    
    ################## predict ######################
    # 测试集输入模型进行预测
    predicted_stock_price = model.predict(x_test)
    # 对预测数据还原---从(0,1)反归一化到原始范围
    predicted_stock_price = sc.inverse_transform(predicted_stock_price)
    # 对真实数据还原---从(0,1)反归一化到原始范围
    real_stock_price = sc.inverse_transform(test_set[60:])
    # 画出真实数据和预测数据的对比曲线
    plt.plot(real_stock_price, color='red', label='MaoTai Stock Price')
    plt.plot(predicted_stock_price, color='blue', label='Predicted MaoTai Stock Price')
    plt.title('MaoTai Stock Price Prediction')
    plt.xlabel('Time')
    plt.ylabel('MaoTai Stock Price')
    plt.legend()
    plt.show()
    
    ##########evaluate##############
    # calculate MSE 均方误差 ---> E[(预测值-真实值)^2] (预测值减真实值求平方后求均值)
    mse = mean_squared_error(predicted_stock_price, real_stock_price)
    # calculate RMSE 均方根误差--->sqrt[MSE]    (对均方误差开方)
    rmse = math.sqrt(mean_squared_error(predicted_stock_price, real_stock_price))
    # calculate MAE 平均绝对误差----->E[|预测值-真实值|](预测值减真实值求绝对值后求均值)
    mae = mean_absolute_error(predicted_stock_price, real_stock_price)
    print('均方误差: %.6f' % mse)
    print('均方根误差: %.6f' % rmse)
    print('平均绝对误差: %.6f' % mae)
     
    我的旨在学过的东西不再忘记(主要使用艾宾浩斯遗忘曲线算法及其它智能学习复习算法)的偏公益性质的完全免费的编程视频学习网站: fanrenyi.com;有各种前端、后端、算法、大数据、人工智能等课程。
    博主25岁,前端后端算法大数据人工智能都有兴趣。
    大家有啥都可以加博主联系方式(qq404006308,微信fan404006308)互相交流。工作、生活、心境,可以互相启迪。
    聊技术,交朋友,修心境,qq404006308,微信fan404006308
    26岁,真心找女朋友,非诚勿扰,微信fan404006308,qq404006308
    人工智能群:939687837

    作者相关推荐

  • 相关阅读:
    Java WEB 之页面间传递特殊字符
    c++ using Handle Class Pattern to accomplish implementation hiding
    c++ simple class template example: Stack
    c++ why can't class template hide its implementation in cpp file?
    c++ what happens when a constructor throws an exception and leaves the object in an inconsistent state?
    c++ 用namespace实现java的package的功能
    c++ virtual 和 pure virtual的区别
    c++ istream(ostream)是如何转换为bool的
    c++ 使用boost regex库 总结
    c++ 如何使用第三方的library
  • 原文地址:https://www.cnblogs.com/Renyi-Fan/p/13732545.html
Copyright © 2011-2022 走看看