zoukankan      html  css  js  c++  java
  • Python基于皮尔逊系数实现股票预测

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Mon Dec  2 14:49:59 2018
     4 
     5 @author: zhen
     6 """
     7 
     8 import matplotlib.pyplot as plt
     9 import numpy as np
    10 import pandas as pd
    11 from datetime import datetime
    12 
    13 def normal(a):  #最大值最小值归一化
    14     return (a - np.min(a)) / (np.max(a) - np.min(a)+0.000001)
    15 
    16 def normalization(x): # np.std:计算矩阵的标准差(方差的算术平方根)
    17     return (x - np.mean(x)) / np.std(x)
    18 
    19 def corrcoef(a,b):
    20     corrc = np.corrcoef(a,b) # 计算皮尔逊相关系数,用于度量两个变量之间的相关性,其值介于-1到1之间
    21     corrc = corrc[0,1]
    22     return (16 * ((1 - corrc) / (1 + corrc)) ** 1) # ** 表示乘方
    23   
    24 startTimeStamp = datetime.now() # 获取当前时间
    25 # 加载数据
    26 filename = 'C:/Users/zhen/.spyder-py3/sh000300_2017.csv'
    27 # 获取第一,二列的数据
    28 all_date = pd.read_csv(filename,usecols=[0, 1, 3], dtype = 'str')
    29 all_date = np.array(all_date)
    30 data = all_date[:, 0]
    31 times = all_date[:, 1]
    32 
    33 data_points = pd.read_csv(filename,usecols=[3])
    34 data_points = np.array(data_points)
    35 data_points = data_points[:,0] #数据
    36 
    37 topk = 10 #只显示top-10
    38 baselen = 100
    39 basebegin = 361
    40 basedata = data[basebegin]+' '+times[basebegin]+'~'+data[basebegin+baselen-1]+' '+times[basebegin+baselen-1]
    41 base = data_points[basebegin:basebegin+baselen]#一天的数据是240个点
    42 length = len(data_points) #数据长度
    43 
    44 # 分割片段
    45 subseries = []
    46 dateseries = []
    47 for j in range(0,length): 
    48     if (j < (basebegin - baselen) or j > (basebegin + baselen - 1)) and j <length - baselen:
    49         subseries.append(data_points[j:j+baselen])
    50         dateseries.append(j) #开始位置
    51 
    52 # 片段搜索
    53 listdistance = []
    54 for i in range(0, len(subseries)):
    55     tt = np.array(subseries[i])
    56     distance = corrcoef(base, tt)
    57     listdistance.append(distance)
    58 
    59 # 排序
    60 index = np.argsort(listdistance,kind='quicksort') #排序,返回排序后的索引序列
    61 
    62 # 显示,要匹配的数据
    63 plt.figure(0)
    64 plt.plot((base),label = basedata, linewidth='2')
    65 plt.legend(loc='upper left')
    66 plt.title('Base data')
    67 
    68 # 原始数据
    69 plt.figure(1)
    70 num = index[0]
    71 length = len(subseries[num])
    72 begin = data[dateseries[num]]+' '+times[dateseries[num]]
    73 end = data[dateseries[num]+length-1]+' '+times[dateseries[num]+length-1]
    74 label = begin+'~'+end
    75 plt.plot((subseries[num]), label=label, linewidth='2')
    76 plt.legend(loc='upper left')
    77 plt.title('Similarity data')
    78 
    79 # 结果集对比
    80 plt.figure(2)
    81 plt.plot(normalization(base),label= basedata,linewidth='2')
    82 length = len(subseries[num])
    83 begin = data[dateseries[num]] + ' ' + times[dateseries[num]]
    84 end = data[dateseries[num] + length - 1] + ' ' + times[dateseries[num] + length - 1]
    85 label = begin + '~' + end
    86 plt.plot(normalization(subseries[num]), label=label, linewidth='3')  
    87 plt.legend(loc='lower right')
    88 plt.title('normal similarity search')
    89 plt.show()
    90 
    91 endTimeStamp=datetime.now()
    92 print('run time', (endTimeStamp-startTimeStamp).seconds, "s")

    结果:

  • 相关阅读:
    Android自定义之ScrollView下拉刷新
    android Viewpager取消预加载及Fragment方法的学习
    Android上下左右滑动,显示底层布局
    android权限大全
    android学习之VelocityTracker
    Android之自定义(上方标题随ViewPager手势慢慢滑动)
    Red Hat Enterprise Linux 7.5安装极点五笔
    Red Hat Enterprise Linux 7.5安装盘内容做本地YUM源
    RHEL7+Oracle11g笔记
    CentOS安装VNC方法
  • 原文地址:https://www.cnblogs.com/yszd/p/10058475.html
Copyright © 2011-2022 走看看