攒了几天,发一个大的
这是前几天投了一家量化分析职位,他给的题目的是写神经网络择时模型,大概就是用神经网络预测收盘价
database类:该类用于获得新浪网中的数据,并将其放入本地数据库。在本地数据库中建立两个表,分别是Data2012to2015和Data2015to2016,表中都含有日期,当日开盘价、当日收盘价、当日最高价、当日最低价。Data2012to2015为训练数据集,Data2015to2016为测试数据集。
package it.cast;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
public class dataBase {
//创建训练集:Data2012to2015和测试集Data2015to2016
public void createDataBase() {
try {
Connection conn = null;
Statement stmt = null;
//链接数据库
Class.forName("oracle.jdbc.driver.OracleDriver");
String url = "jdbc:oracle:thin:@localhost:1521:ORCL";
String UserName = "system";
String password = "manager";
conn = DriverManager.getConnection(url, UserName, password);
stmt = conn.createStatement();
historyShare( conn, stmt);
}catch (Exception e) {
e.printStackTrace();
}
}
private void historyShare( Connection conn, Statement stmt)
throws SQLException, MalformedURLException, IOException,
UnsupportedEncodingException {
//创建表格
//表格列为:股票id号、日期、开盘价、最高价、收盘价、最低价、成交量
String sql = "create table Data2015to2016(stokeid integer not null primary key ," +
"data varchar2(20), openPrice varchar2(20), highPrice varchar2(20), overPrice varchar2(20),lowPrice varchar2(20)," +
"vol varchar2(20))";
stmt.executeUpdate(sql);
URL ur = null;
ur = new URL("http://biz.finance.sina.com.cn/stock/flash_hq/kline_data.php?&rand=random(10000)&symbol=sz000001&end_date=20161118&begin_date=20151118&type=plain");
HttpURLConnection uc = (HttpURLConnection) ur.openConnection();
BufferedReader reader = new BufferedReader(new InputStreamReader(ur.openStream(),"GBK"));
String line;
PreparedStatement stmt1 = null;
int i=1;
//插入数据
while((line = reader.readLine()) != null){
//普通股票
String sql1 = "insert into Data2015to2016 values(?,?, ?, ?, ?, ?, ?)";
stmt1 = conn.prepareStatement(sql1);
String[] data=line.split(",");
String date = data[0];
String openPrice = data[1];
String highPrice = data[2];
String overPrice = data[3];
String lowPrice = data[4];
stmt1.setInt(1, i++);
stmt1.setString(2, data[0]);
stmt1.setString(3, data[1]);
stmt1.setString(4, data[2]);
stmt1.setString(5, data[3]);
stmt1.setString(6, data[4]);
stmt1.setString(7, data[5]);
stmt1.executeUpdate();
stmt1.close();
}
}
}
Methods类:由于java没有现成的包可以直接得出某只股票的波动率指标、短期和长期均线指标等指标,由于一些指标在网上没有找到,例如动量和反转指标:REVS5,就用了动量指标MTM。所以在百度百科等资料中搜集了一些公式, 分别对这些公式编写代码,就能观测到的数据来说,是准确的。
最后采用了8个指标,分别是波动率指标:EMV;短期和长期均线指标:EMA5和EMA60,MA5和MA60;动量指标MTM;量能指标:MACD;能量指标:CR5.以这8个指标为自变量,收盘价为因变量建立神经网络模型。
package it.cast;
import java.util.ArrayList;
import java.util.List;
public class Methods {
//搜狗百科:A=(今日最高+今日最低)/2;B=(前日最高+前日最低)/2;C=今日最高-今日最低;2.EM=(A-B)*C/今日成交额;3.EMV=N日内EM的累和;4.MAEMV=EMV的M日简单移动平均.参数N为14,参数M为9
public List<Double> EMV(List<Double>highPrice,List<Double>lowPrice,List<Double>vol){
List<Double>EM = new ArrayList<Double>();
for(int i = 2;i<highPrice.size();i++){
double A = (highPrice.get(i)+lowPrice.get(i))/2;
double B = (highPrice.get(i-2)+lowPrice.get(i-2))/2;
double C = highPrice.get(i)-lowPrice.get(i);
EM.add(((A-B)*C)/vol.get(i));
}
List<Double>EMV = new ArrayList<Double>();
//取N为14,即14日的EM值之和;M为9,即9日的移动平均
int N = 14;
int M = 9;
for(int i = N;i<EM.size()+1;i++){
//14日累和
double sum = 0;
for(int j = i-N;j<i;j++){
sum += EM.get(j);
}
EMV.add(sum);
}
List<Double>MAEMV = new ArrayList<Double>();
for(int i = M;i<EMV.size()+1;i++){
//9日移动平均
double sum = 0;
for(int j = i-M;j<i;j++){
sum += EMV.get(j);
}
sum = sum/M;
MAEMV.add(sum);
}
return MAEMV;
}
//EMA=(当日或当期收盘价-上一日或上期EXPMA)/N+上一日或上期EXPMA,其中,首次上期EXPMA值为上一期收盘价,N为天数。
public List<Double> EMA5(List<Double>overPrice){
//取20121118年收盘价为初始EXPMA
List<Double>EMA5 = new ArrayList<Double>();
for(int i = 0;i<5;i++){
EMA5.add(overPrice.get(i));
}
for(int i = 5;i<overPrice.size();i++){
EMA5.add((overPrice.get(i)-EMA5.get(i-5))/5+EMA5.get(i-5));
}
return EMA5;
}
public List<Double> EMA60(List<Double>overPrice){
//取20121118年收盘价为初始EXPMA
List<Double>EMA60 = new ArrayList<Double>();
for(int i = 0;i<60;i++){
EMA60.add(overPrice.get(i));
}
for(int i = 60;i<overPrice.size();i++){
EMA60.add((overPrice.get(i)-EMA60.get(i-60))/60+EMA60.get(i-60));
}
return EMA60;
}
//5日均线
public List<Double> MA5(List<Double>overPrice){
List<Double>MA5 = new ArrayList<Double>();
for(int i = 5;i<overPrice.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-5;j--){
sum += overPrice.get(j);
}
sum = sum/5;
MA5.add(sum);
}
return MA5;
}
//60日均线
public List<Double> MA60(List<Double>overPrice){
List<Double>MA60 = new ArrayList<Double>();
for(int i = 60;i<overPrice.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-60;j--){
sum += overPrice.get(j);
}
sum = sum/60;
MA60.add(sum);
}
return MA60;
}
//动量指标MTM,1.MTM=当日收盘价-N日前收盘价;2.MTMMA=MTM的M日移动平均;3.参数N一般设置为12日参数M一般设置为6,表中当动量值减低或反转增加时,应为买进或卖出时机
public List<Double> MTM(List<Double>overPrice){
List<Double>MTM = new ArrayList<Double>();
List<Double>MTMlist = new ArrayList<Double>();
int N = 12;
int M = 6;
for(int i = 12;i<overPrice.size();i++){
MTM.add(overPrice.get(i)-overPrice.get(i-12));
}
//移动平均参数为6
for(int i = 6;i<MTM.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-6;j--){
sum += MTM.get(j);
}
sum = sum/6;
MTMlist.add(sum);
}
return MTMlist;
}
//百度百科:http://baike.baidu.com/link?url=XQf2I-JIyNR1AEM_EnMnuU90U1vmJDoXukUe1fQVsBA1Y_fqAA8dj7DoxLCoh5U-YysBkVT5aIZLXeG2g1snoK:量能指标就是通过动态分析成交量的变化,
public List<Double> MACD(List<Double>vol){
int shortN = 12;
List<Double>Short = new ArrayList<Double>();
for(int i = shortN;i<vol.size()+1;i++){
Short.add(2*vol.get(i-1)+(shortN-1)*vol.get(i-shortN));
}
int longN = 26;
List<Double>Long = new ArrayList<Double>();
for(int i = longN;i<vol.size()+1;i++){
Long.add(2*vol.get(i-1)+(longN-1)*vol.get(i-longN));
}
// 取两个序列中较短序列的长度
int length = 0;
if(Short.size()>Long.size()){
length = Long.size();
}else{
length = Short.size();
}
List<Double>DIFF1 = new ArrayList<Double>();
for(int i = length-1;i>=0;i--){
DIFF1.add(Short.get(i)-Long.get(i));
}
List<Double>DIFF = new ArrayList<Double>();
for(int i = 0;i<DIFF1.size();i++){
DIFF.add(DIFF1.get(DIFF1.size()-i-1));
}
List<Double>DEA = new ArrayList<Double>();
for(int i = 0;i<DIFF.size()-1;i++){
DEA.add(2*DIFF.get(i+1)+(9-1)*DIFF.get(i));
}
List<Double>MACD = new ArrayList<Double>();
for(int i = 1;i<DIFF.size();i++){
MACD.add(DIFF.get(i)-DEA.get(i-1));
}
return MACD;
}
//能量指标:CR,见百度百科:http://baike.baidu.com/link?url=v5yYFep6wZioav0P-LOruuhkzjho6PqzQqfEBj5TYQLfaadLSADSQVl0njP7k1zY78KJMoBFrE4OO4wYolZXbMnRRQi7U66R0X2jeSV3ZoXKeuG2zEbqEqP4CnyiF7j6
public List<Double> CR5(List<Double>overPrice,List<Double>highPrice,List<Double>lowPrice,List<Double>openPrice){
List<Double> YM = new ArrayList<Double>();
List<Double> HYM = new ArrayList<Double>();
List<Double> YML = new ArrayList<Double>();
List<Double> CR = new ArrayList<Double>();
for(int i = 0;i<overPrice.size();i++){
YM.add((highPrice.get(i)+overPrice.get(i)+lowPrice.get(i)+openPrice.get(i))/4);
}
//p1表示5日以来多方力量总和,p2表示5日以来空方力量总和
for(int i = 6;i<highPrice.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-5;j--){
sum += highPrice.get(j)-YM.get(j-1);
}
HYM.add(sum);
}
//p2表示5日以来空方力量总和,p2表示5日以来空方力量总和
for(int i = 6;i<lowPrice.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-5;j--){
sum += YM.get(j-1)-lowPrice.get(j);
}
YML.add(sum);
}
for(int i = 0;i<YML.size();i++){
double temp = (double)HYM.get(i)/YML.get(i);
if(temp<0){
CR.add((double) 0);
}else{
CR.add(temp);
}
}
return CR;
}
public double[][] bpTrain(List<Double>overPrice,List<Double>highPrice,List<Double>lowPrice,List<Double>openPrice,List<Double>vol){
List<Double>EMV = EMV(highPrice, lowPrice, vol);
List<Double>EMA5 = EMA5(overPrice);
List<Double>EMA60 = EMA60(overPrice);
List<Double>MA5 = MA5(overPrice);
List<Double>MA60 = MA60(overPrice);
List<Double>MTM = MTM(overPrice);
List<Double>MACD = MACD(vol);
List<Double>CR5 = CR5(overPrice, highPrice, lowPrice, openPrice);
int length = 0;
if(EMA60.size()>MA60.size()){
length = MA60.size();
}else{
length = EMA60.size();
}
List<ArrayList<Double>>datalist = new ArrayList<ArrayList<Double>>();
for(int i = 0;i<length;i++){
ArrayList<Double>list = new ArrayList<Double>();
//list.add(EMV.get(EMV.size()-length+i));
list.add(EMA5.get(EMA5.size()-length+i));
list.add(EMA60.get(EMA60.size()-length+i));
list.add(MA5.get(MA5.size()-length+i));
list.add(MA60.get(MA60.size()-length+i));
list.add(MTM.get(MTM.size()-length+i));
// list.add(MACD.get(MACD.size()-length+i));
list.add(CR5.get(CR5.size()-length+i));
datalist.add(list);
}
double [][]data = new double[datalist.size()][6];
for(int i = 0;i<datalist.size();i++){
for(int j = 0;j<6;j++){
data[i][j] = datalist.get(i).get(j);
System.out.print(data[i][j]+" ");
}
System.out.println();
}
return data;
}
}
BPnet类:这里想建立输入单元为8个,两层隐含层,每个隐含层为13个单元,输出层单元为1的神经网络。
首先初始化输入层到隐含层,隐含层之间,以及隐含层到输出层的权重矩阵;
其次利用权重矩阵和输入层分别计算出每个隐含层节点数据
之后利用计算得出的输出层数据与真实值进行比较,并逐层调节权重;
反复上述过程直至精度达到要求或是达到迭代次数的要求;
这里设置迭代次数为5000次;
利用的测试数据集为Data2012to2015
下图为训练之后的模型对Data2012to2015自身进行拟合的效果:(这里由于自变量大概是10左右的数据,所以在利用激活函数1/(1+e^-ax))时,a取了0.01
package it.cast;
import java.util.Random;
public class BPnet {
public double[][] layer;//神经网络各层节点
public double[][] layerErr;//神经网络各节点误差
public double[][][] layer_weight;//各层节点权重
public double[][][] layer_weight_delta;//各层节点权重动量
public double mobp;//动量系数
public double rate;//学习系数
public BPnet(int[] layernum, double rate, double mobp){
this.mobp = mobp;
this.rate = rate;
layer = new double[layernum.length][];
layerErr = new double[layernum.length][];
layer_weight = new double[layernum.length][][];
layer_weight_delta = new double[layernum.length][][];
Random random = new Random();
for(int l=0;l<layernum.length;l++){
layer[l]=new double[layernum[l]];
layerErr[l]=new double[layernum[l]];
if(l+1<layernum.length){
layer_weight[l]=new double[layernum[l]+1][layernum[l+1]];
layer_weight_delta[l]=new double[layernum[l]+1][layernum[l+1]];
for(int j=0;j<layernum[l]+1;j++)
for(int i=0;i<layernum[l+1];i++)
layer_weight[l][j][i]=random.nextDouble();//随机初始化权重
}
}
}
//逐层向前计算输出
public double[] computeOut(double[] in){
for(int l=1;l<layer.length;l++){
for(int j=0;j<layer[l].length;j++){
double z=layer_weight[l-1][layer[l-1].length][j];
for(int i=0;i<layer[l-1].length;i++){
layer[l-1][i]=l==1?in[i]:layer[l-1][i];
z+=layer_weight[l-1][i][j]*layer[l-1][i];
}
// System.out.println(z+"####");
layer[l][j]=1/(1+Math.exp(-0.01*z));
// System.out.println("&&**"+layer[l][j]);
}
}
//System.out.println("&&^^**"+layer[layer.length-1][0]);
return layer[layer.length-1];
}
//逐层反向计算误差并修改权重
public void updateWeight(double[] tar){
int l=layer.length-1;
for(int j=0;j<layerErr[l].length;j++)
layerErr[l][j]=layer[l][j]*(1-layer[l][j])*(1/(1+Math.exp(-0.01*tar[j]))-layer[l][j]);
while(l-->0){
for(int j=0;j<layerErr[l].length;j++){
double z = 0.0;
for(int i=0;i<layerErr[l+1].length;i++){
z=z+l>0?layerErr[l+1][i]*layer_weight[l][j][i]:0;
layer_weight_delta[l][j][i]= mobp*layer_weight_delta[l][j][i]+rate*layerErr[l+1][i]*layer[l][j];//隐含层动量调整
layer_weight[l][j][i]+=layer_weight_delta[l][j][i];//隐含层权重调整
if(j==layerErr[l].length-1){
layer_weight_delta[l][j+1][i]= mobp*layer_weight_delta[l][j+1][i]+rate*layerErr[l+1][i];//截距动量调整
layer_weight[l][j+1][i]+=layer_weight_delta[l][j+1][i];//截距权重调整
}
}
layerErr[l][j]=z*layer[l][j]*(1-layer[l][j]);//记录误差
}
}
}
public void train(double[] in, double[] tar){
double[] out = computeOut(in);
updateWeight(tar);
}
}
从图中可以看出2012年初,股市变化幅度很大时,模型拟合效果稍差,但总体拟合效果较好。(红线表示拟合曲线,蓝线表示真实收盘价)
测试数据集采用的是Data2015to2016,即2015年至2016年数据,拟合拟合效果如下:
从图中可以看出曲线可以拟合大致趋势,但是不能很好的拟合波动,可能是由于对训练数据集过渡拟合的原因。
BackProce类:该类计算了如果按照神经网络模型对该股票进行操作的结果,采用的策略是,如果下一天的预测值高于当天的收盘价,就买入,低于就卖出,设置初始账户金额为10000.
可得到最后的收益率为0.18364521221914928,账户金额为:11836.452122191493。
累计收益率如下图:
累计收益率明显呈现上升趋势。
package it.cast;
import java.util.ArrayList;
import java.util.List;
public class BackProce {
public List<ArrayList<Double>> selectChance(List<ArrayList<Double>>result,double account){
double accountF = account;
System.out.println("初始账户为: "+account);
ArrayList<Double>expect = new ArrayList<Double>();
ArrayList<Double>target = new ArrayList<Double>();
for(int i = 0;i<result.size();i++){
expect.add(result.get(i).get(0));
target.add(result.get(i).get(1));
}
List<ArrayList<Double>>chance = new ArrayList<ArrayList<Double>>();
for(int i = 1;i<expect.size();i++){
if(expect.get(i)>target.get(i-1)){
//买入
account += account*(target.get(i)-target.get(i-1))/target.get(i-1);
}
ArrayList<Double>list = new ArrayList<Double>();
list.add((account-accountF)/accountF);
list.add((double) i);
chance.add(list);
}
System.out.println("期末账户为: "+account);
System.out.println("年化收益率为: "+(account-accountF)/accountF);
return chance;
}
}
辅助类Graph:该类借助了jfree包,用于绘制图像
package it.cast;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Font;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import javax.swing.JPanel;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.category.LineAndShapeRenderer;
import org.jfree.chart.title.TextTitle;
import org.jfree.data.category.CategoryDataset;
import org.jfree.data.category.DefaultCategoryDataset;
import org.jfree.ui.ApplicationFrame;
import org.jfree.ui.HorizontalAlignment;
import org.jfree.ui.RectangleEdge;
public class Graph extends ApplicationFrame{
ChartPanel frame1;
private static final long serialVersionUID = 1L;
public Graph(String s , List<ArrayList<Double>> excel) {
super(s);
setContentPane(createDemoLine(excel));
}
public static DefaultCategoryDataset createDataset(List<ArrayList<Double>> excel) {
DefaultCategoryDataset linedataset = new DefaultCategoryDataset();
for (int i=0; i <excel.size(); i++) {
linedataset.addValue(excel.get(i).get(0), "expect", excel.get(i).get(1));
//linedataset.addValue(excel.get(i).get(1), "target", Integer.toString(i+1));
}
return linedataset;
}
public static JPanel createDemoLine(List<ArrayList<Double>> excel) {
JFreeChart jfreechart = createChart(createDataset(excel));
return new ChartPanel(jfreechart);
}
// 生成图表主对象JFreeChart
public static JFreeChart createChart(DefaultCategoryDataset linedataset) {
// 定义图表对象
JFreeChart chart = ChartFactory.createLineChart("Cumulative rate of return", //折线图名称
"time", // 横坐标名称
"Value", // 纵坐标名称
linedataset, // 数据
PlotOrientation.VERTICAL, // 水平显示图像
true, // include legend
false, // tooltips
false // urls
);
// chart.setBackgroundPaint(Color.red);
CategoryPlot plot = chart.getCategoryPlot();
// plot.setDomainGridlinePaint(Color.red);
plot.setDomainGridlinesVisible(true);
// 5,设置水平网格线颜色
// plot.setRangeGridlinePaint(Color.blue);
// 6,设置是否显示水平网格线
plot.setRangeGridlinesVisible(true);
plot.setRangeGridlinesVisible(true); //是否显示格子线
//plot.setBackgroundAlpha(f); //设置背景透明度
NumberAxis rangeAxis = (NumberAxis)plot.getRangeAxis();
rangeAxis.setStandardTickUnits(NumberAxis.createIntegerTickUnits());
rangeAxis.setAutoRangeIncludesZero(true);
rangeAxis.setUpperMargin(0.20);
rangeAxis.setLabelAngle(Math.PI / 2.0);
rangeAxis.setAutoRange(false);
FileOutputStream fos_jpg=null;
try{
fos_jpg=new FileOutputStream("D:\ok_bing.jpg");
/*
* 第二个参数如果为100,会报异常:
* java.lang.IllegalArgumentException: The 'quality' must be in the range 0.0f to 1.0f
* 限制quality必须小于等于1,把100改成 0.1f。
*/
// ChartUtilities.writeChartAsJPEG(fos_jpg, 0.99f, chart, 600, 300, null);
ChartUtilities.writeChartAsJPEG(fos_jpg, chart, 900, 400);
}catch(Exception e){
System.out.println("[e]"+e);
}finally{
try{
fos_jpg.close();
}catch(Exception e){
}
}
return chart;
}
}
主函数类testClass:
package it.cast;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class testClass {
public static void main(String[] args) {
dataBase data = new dataBase();
// data.createDataBase();
try{
Connection conn = null;
Statement stmt = null;
//链接数据库
Class.forName("oracle.jdbc.driver.OracleDriver");
String url = "jdbc:oracle:thin:@localhost:1521:ORCL";
String UserName = "system";
String password = "manager";
conn = DriverManager.getConnection(url, UserName, password);
stmt = conn.createStatement();
String sql2="select * from Data2012to2015";
ResultSet rs = stmt.executeQuery(sql2);
//创建序列
List<Double> openPrice = new ArrayList<Double>();
List<Double> highPrice = new ArrayList<Double>();
List<Double> overPrice = new ArrayList<Double>();
List<Double> lowPrice = new ArrayList<Double>();
List<Double> vol = new ArrayList<Double>();
while (rs.next()){
openPrice.add(Double.parseDouble(rs.getString("OPENPRICE")));
highPrice.add(Double.parseDouble(rs.getString("HIGHPRICE")));
overPrice.add(Double.parseDouble(rs.getString("OVERPRICE")));
lowPrice.add(Double.parseDouble(rs.getString("LOWPRICE")));
vol.add(Double.parseDouble(rs.getString("VOL")));
}
Methods m = new Methods();
double [][]dataset = m.bpTrain(overPrice, highPrice, lowPrice, openPrice, vol);
double [][]target = new double[dataset.length][];
for(int i = 0;i<dataset.length;i++){
target[i] = new double[1];
target[i][0] = overPrice.get(overPrice.size()-dataset.length+i);
}
String sql3="select * from Data2015to2016";
ResultSet rs2 = stmt.executeQuery(sql3);
//创建序列
List<Double> openPrice2 = new ArrayList<Double>();
List<Double> highPrice2 = new ArrayList<Double>();
List<Double> overPrice2 = new ArrayList<Double>();
List<Double> lowPrice2 = new ArrayList<Double>();
List<Double> vol2 = new ArrayList<Double>();
while (rs2.next()){
openPrice2.add(Double.parseDouble(rs.getString("OPENPRICE")));
highPrice2.add(Double.parseDouble(rs.getString("HIGHPRICE")));
overPrice2.add(Double.parseDouble(rs.getString("OVERPRICE")));
lowPrice2.add(Double.parseDouble(rs.getString("LOWPRICE")));
vol2.add(Double.parseDouble(rs.getString("VOL")));
}
Methods m2 = new Methods();
double [][]dataset2 = m2.bpTrain(overPrice2, highPrice2, lowPrice2, openPrice2, vol2);
double [][]target2 = new double[dataset2.length][];
for(int i = 0;i<dataset2.length;i++){
target2[i] = new double[1];
target2[i][0] = overPrice2.get(overPrice2.size()-dataset2.length+i);
}
BPnet bp = new BPnet(new int[]{6,13,13,1}, 0.15, 0.8);
//迭代训练5000次
for(int n=0;n<50000;n++)
for(int i=0;i<dataset.length;i++)
bp.train(dataset[i], target[i]);
//测试数据集
double []result = new double[dataset2.length];
List<ArrayList<Double>>resultList = new ArrayList<ArrayList<Double>>();
for(int j=0;j<dataset2.length;j++){
double []a = bp.computeOut(dataset2[j]);
ArrayList<Double>list = new ArrayList<Double>();
result[j] = 100*(-Math.log(1/a[0]-1));
list.add(result[j]);
list.add(target2[j][0]);
resultList.add(list);
System.out.println(Arrays.toString(dataset2[j])+":"+result[j]+" real:"+target2[j][0]);
}
//new Graph("1",resultList);
BackProce b = new BackProce();
double account = 10000;
List<ArrayList<Double>>chance = b.selectChance(resultList,account);
new Graph("1",chance);
}catch (Exception e) {
e.printStackTrace();
// TODO: handle exception
}
System.out.println("End");
}
}
缺点:1、只能绘制基本图像,没有找到方法将特殊点标出,如:能够获取在什么时间点买入,但是不知怎么在特定点用其他颜色标出。
2、神经网络模型对训练数据拟合很好,但是对测试数据拟合效果不佳,猜测原因可能是过拟合或是有些其他主要的变量因素没有考虑进去。