zoukankan      html  css  js  c++  java
  • 利用神经网络预测股票收盘价(含源代码)

    攒了几天,发一个大的

    这是前几天投了一家量化分析职位,他给的题目的是写神经网络择时模型,大概就是用神经网络预测收盘价

    database:该类用于获得新浪网中的数据,并将其放入本地数据库。在本地数据库中建立两个表,分别是Data2012to2015和Data2015to2016,表中都含有日期,当日开盘价、当日收盘价、当日最高价、当日最低价。Data2012to2015为训练数据集,Data2015to2016为测试数据集。

    package it.cast;

    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.io.UnsupportedEncodingException;
    import java.net.HttpURLConnection;
    import java.net.MalformedURLException;
    import java.net.URL;
    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.PreparedStatement;
    import java.sql.SQLException;
    import java.sql.Statement;

    public class dataBase {
        //创建训练集:Data2012to2015和测试集Data2015to2016
        public  void createDataBase() {
        try {
            Connection conn = null;
            Statement stmt = null;
            //链接数据库
            Class.forName("oracle.jdbc.driver.OracleDriver");
            String url = "jdbc:oracle:thin:@localhost:1521:ORCL";
            String UserName = "system";    
            String password = "manager";
            conn = DriverManager.getConnection(url, UserName, password);
            stmt = conn.createStatement();
            historyShare( conn, stmt);
            
        }catch (Exception e) {

                e.printStackTrace();

            }
        }

        private void historyShare( Connection conn, Statement stmt)
                throws SQLException, MalformedURLException, IOException,
                UnsupportedEncodingException {
            //创建表格
            //表格列为:股票id号、日期、开盘价、最高价、收盘价、最低价、成交量
            String sql = "create table Data2015to2016(stokeid integer not null primary key ," +
                    "data varchar2(20), openPrice varchar2(20), highPrice varchar2(20), overPrice varchar2(20),lowPrice varchar2(20)," +
                    "vol varchar2(20))";

            stmt.executeUpdate(sql);


            URL ur = null;
            ur = new URL("http://biz.finance.sina.com.cn/stock/flash_hq/kline_data.php?&rand=random(10000)&symbol=sz000001&end_date=20161118&begin_date=20151118&type=plain");

            HttpURLConnection uc = (HttpURLConnection) ur.openConnection();

            BufferedReader reader = new BufferedReader(new InputStreamReader(ur.openStream(),"GBK"));
            String line;
            PreparedStatement stmt1 = null;
            int i=1;
            //插入数据
            while((line = reader.readLine()) != null){
                //普通股票
                String sql1 = "insert into Data2015to2016 values(?,?, ?, ?, ?, ?, ?)";
                stmt1 = conn.prepareStatement(sql1);
                String[] data=line.split(",");
                String date = data[0];
                String openPrice = data[1];
                String highPrice = data[2];
                String overPrice = data[3];
                String lowPrice = data[4];
                stmt1.setInt(1, i++);
                stmt1.setString(2, data[0]);
                stmt1.setString(3, data[1]);
                stmt1.setString(4, data[2]);
                stmt1.setString(5, data[3]);
                stmt1.setString(6, data[4]);
                stmt1.setString(7, data[5]);
                stmt1.executeUpdate();
                stmt1.close();
            }
        }

    }

    Methods:由于java没有现成的包可以直接得出某只股票的波动率指标、短期和长期均线指标等指标,由于一些指标在网上没有找到,例如动量和反转指标:REVS5,就用了动量指标MTM。所以在百度百科等资料中搜集了一些公式, 分别对这些公式编写代码,就能观测到的数据来说,是准确的。

    最后采用了8个指标,分别是波动率指标:EMV;短期和长期均线指标:EMA5和EMA60,MA5和MA60;动量指标MTM;量能指标:MACD;能量指标:CR5.以这8个指标为自变量,收盘价为因变量建立神经网络模型。

    package it.cast;

    import java.util.ArrayList;
    import java.util.List;

    public class Methods {
        
        

        //搜狗百科:A=(今日最高+今日最低)/2;B=(前日最高+前日最低)/2;C=今日最高-今日最低;2.EM=(A-B)*C/今日成交额;3.EMV=N日内EM的累和;4.MAEMV=EMV的M日简单移动平均.参数N为14,参数M为9
        public List<Double> EMV(List<Double>highPrice,List<Double>lowPrice,List<Double>vol){
            List<Double>EM = new ArrayList<Double>();
            for(int i = 2;i<highPrice.size();i++){
                double A = (highPrice.get(i)+lowPrice.get(i))/2;
                double B = (highPrice.get(i-2)+lowPrice.get(i-2))/2;
                double C = highPrice.get(i)-lowPrice.get(i);
                EM.add(((A-B)*C)/vol.get(i));
            }

            List<Double>EMV = new ArrayList<Double>();
            //取N为14,即14日的EM值之和;M为9,即9日的移动平均
            int N = 14;
            int M = 9;
            for(int i = N;i<EM.size()+1;i++){
                //14日累和
                double sum = 0;
                for(int j = i-N;j<i;j++){
                    sum += EM.get(j);
                }
                EMV.add(sum);
            }

            List<Double>MAEMV = new ArrayList<Double>();
            for(int i = M;i<EMV.size()+1;i++){
                //9日移动平均
                double sum = 0;
                for(int j = i-M;j<i;j++){
                    sum += EMV.get(j);
                }
                sum = sum/M;
                MAEMV.add(sum);
            }
            return MAEMV;
        }

        //EMA=(当日或当期收盘价-上一日或上期EXPMA)/N+上一日或上期EXPMA,其中,首次上期EXPMA值为上一期收盘价,N为天数。
        public List<Double> EMA5(List<Double>overPrice){
            //取20121118年收盘价为初始EXPMA
            List<Double>EMA5 = new ArrayList<Double>();
            for(int i = 0;i<5;i++){
                EMA5.add(overPrice.get(i));
            }
            for(int i = 5;i<overPrice.size();i++){
                EMA5.add((overPrice.get(i)-EMA5.get(i-5))/5+EMA5.get(i-5));

            }
            return EMA5;
        }


        public List<Double> EMA60(List<Double>overPrice){
            //取20121118年收盘价为初始EXPMA
            List<Double>EMA60 = new ArrayList<Double>();
            for(int i = 0;i<60;i++){
                EMA60.add(overPrice.get(i));
            }
            for(int i = 60;i<overPrice.size();i++){
                EMA60.add((overPrice.get(i)-EMA60.get(i-60))/60+EMA60.get(i-60));
            }
            return EMA60;
        }

        //5日均线
        public List<Double> MA5(List<Double>overPrice){
            List<Double>MA5 = new ArrayList<Double>();
            for(int i = 5;i<overPrice.size()+1;i++){
                double sum = 0;
                for(int j = i-1;j>=i-5;j--){
                    sum += overPrice.get(j);
                }
                sum = sum/5;
                MA5.add(sum);
            }
            return MA5;
        }


        //60日均线
        public List<Double> MA60(List<Double>overPrice){
            List<Double>MA60 = new ArrayList<Double>();
            for(int i = 60;i<overPrice.size()+1;i++){
                double sum = 0;
                for(int j = i-1;j>=i-60;j--){
                    sum += overPrice.get(j);
                }
                sum = sum/60;
                MA60.add(sum);
            }
            return MA60;
        }

        //动量指标MTM,1.MTM=当日收盘价-N日前收盘价;2.MTMMA=MTM的M日移动平均;3.参数N一般设置为12日参数M一般设置为6,表中当动量值减低或反转增加时,应为买进或卖出时机
        public List<Double> MTM(List<Double>overPrice){
            List<Double>MTM = new ArrayList<Double>();
            List<Double>MTMlist = new ArrayList<Double>();
            int N = 12;
            int M = 6;
            for(int i = 12;i<overPrice.size();i++){
                MTM.add(overPrice.get(i)-overPrice.get(i-12));
            }
            
            //移动平均参数为6
            for(int i = 6;i<MTM.size()+1;i++){
                double sum = 0;
                for(int j = i-1;j>=i-6;j--){
                    sum += MTM.get(j);
                }
                sum = sum/6;
                MTMlist.add(sum);
            }
            return MTMlist;
        }
        
        
        //百度百科:http://baike.baidu.com/link?url=XQf2I-JIyNR1AEM_EnMnuU90U1vmJDoXukUe1fQVsBA1Y_fqAA8dj7DoxLCoh5U-YysBkVT5aIZLXeG2g1snoK:量能指标就是通过动态分析成交量的变化,
        public List<Double> MACD(List<Double>vol){
            int shortN = 12;
            List<Double>Short = new ArrayList<Double>();
            for(int i = shortN;i<vol.size()+1;i++){
                Short.add(2*vol.get(i-1)+(shortN-1)*vol.get(i-shortN));
            }
            int longN = 26;
            List<Double>Long = new ArrayList<Double>();
            for(int i = longN;i<vol.size()+1;i++){
                Long.add(2*vol.get(i-1)+(longN-1)*vol.get(i-longN));
            }
            
            //    取两个序列中较短序列的长度
            int length = 0;
            if(Short.size()>Long.size()){
                length = Long.size();
            }else{
                length = Short.size();
            }
            
            List<Double>DIFF1 = new ArrayList<Double>();
            for(int i = length-1;i>=0;i--){
                DIFF1.add(Short.get(i)-Long.get(i));
            }
            List<Double>DIFF = new ArrayList<Double>();
            for(int i = 0;i<DIFF1.size();i++){
                DIFF.add(DIFF1.get(DIFF1.size()-i-1));
            }
            List<Double>DEA = new ArrayList<Double>();
            for(int i = 0;i<DIFF.size()-1;i++){
                DEA.add(2*DIFF.get(i+1)+(9-1)*DIFF.get(i));
            }
            
            List<Double>MACD = new ArrayList<Double>();
            for(int i = 1;i<DIFF.size();i++){
                MACD.add(DIFF.get(i)-DEA.get(i-1));
            }
            return MACD;
        }
        
        
        //能量指标:CR,见百度百科:http://baike.baidu.com/link?url=v5yYFep6wZioav0P-LOruuhkzjho6PqzQqfEBj5TYQLfaadLSADSQVl0njP7k1zY78KJMoBFrE4OO4wYolZXbMnRRQi7U66R0X2jeSV3ZoXKeuG2zEbqEqP4CnyiF7j6
        public List<Double> CR5(List<Double>overPrice,List<Double>highPrice,List<Double>lowPrice,List<Double>openPrice){
            List<Double> YM = new ArrayList<Double>();
            List<Double> HYM = new ArrayList<Double>();
            List<Double> YML = new ArrayList<Double>();
            List<Double> CR = new ArrayList<Double>();
            for(int i = 0;i<overPrice.size();i++){
                YM.add((highPrice.get(i)+overPrice.get(i)+lowPrice.get(i)+openPrice.get(i))/4);
            }
            //p1表示5日以来多方力量总和,p2表示5日以来空方力量总和
            for(int i = 6;i<highPrice.size()+1;i++){
                double sum = 0;
                for(int j = i-1;j>=i-5;j--){
                    sum += highPrice.get(j)-YM.get(j-1);
                }
                HYM.add(sum);
            }
            //p2表示5日以来空方力量总和,p2表示5日以来空方力量总和
            for(int i = 6;i<lowPrice.size()+1;i++){
                double sum = 0;
                for(int j = i-1;j>=i-5;j--){
                    sum += YM.get(j-1)-lowPrice.get(j);
                }
                YML.add(sum);
            }
            for(int i = 0;i<YML.size();i++){
                double temp = (double)HYM.get(i)/YML.get(i);
                if(temp<0){
                    CR.add((double) 0);
                }else{
                    CR.add(temp);
                }
                
            }
            return CR;
                    

        }
        
        public double[][] bpTrain(List<Double>overPrice,List<Double>highPrice,List<Double>lowPrice,List<Double>openPrice,List<Double>vol){
            List<Double>EMV = EMV(highPrice, lowPrice, vol);
            List<Double>EMA5 = EMA5(overPrice);
            List<Double>EMA60 = EMA60(overPrice);
            List<Double>MA5 = MA5(overPrice);
            List<Double>MA60 = MA60(overPrice);
            List<Double>MTM = MTM(overPrice);
            List<Double>MACD = MACD(vol);
            List<Double>CR5 = CR5(overPrice, highPrice, lowPrice, openPrice);
            
            int length = 0;
            if(EMA60.size()>MA60.size()){
                length = MA60.size();
            }else{
                length = EMA60.size();
            }
            List<ArrayList<Double>>datalist = new ArrayList<ArrayList<Double>>();
            for(int i = 0;i<length;i++){
                ArrayList<Double>list = new ArrayList<Double>();
                //list.add(EMV.get(EMV.size()-length+i));
                list.add(EMA5.get(EMA5.size()-length+i));
                list.add(EMA60.get(EMA60.size()-length+i));
                list.add(MA5.get(MA5.size()-length+i));
                list.add(MA60.get(MA60.size()-length+i));
                list.add(MTM.get(MTM.size()-length+i));
        //        list.add(MACD.get(MACD.size()-length+i));
                list.add(CR5.get(CR5.size()-length+i));
                datalist.add(list);
            }
            double [][]data = new double[datalist.size()][6];
            for(int i = 0;i<datalist.size();i++){
                for(int j = 0;j<6;j++){
                    data[i][j] = datalist.get(i).get(j);
                    System.out.print(data[i][j]+"  ");
                }
                System.out.println();
            }
            return data;
        }
        
    }

    BPnet:这里想建立输入单元为8个,两层隐含层,每个隐含层为13个单元,输出层单元为1的神经网络。

    首先初始化输入层到隐含层,隐含层之间,以及隐含层到输出层的权重矩阵;

    其次利用权重矩阵和输入层分别计算出每个隐含层节点数据

    之后利用计算得出的输出层数据与真实值进行比较,并逐层调节权重;

    反复上述过程直至精度达到要求或是达到迭代次数的要求;

    这里设置迭代次数为5000次;

    利用的测试数据集为Data2012to2015

    下图为训练之后的模型对Data2012to2015自身进行拟合的效果:(这里由于自变量大概是10左右的数据,所以在利用激活函数1/(1+e^-ax))时,a取了0.01

    package it.cast;

    import java.util.Random;

    public class BPnet {
        public double[][] layer;//神经网络各层节点
        public double[][] layerErr;//神经网络各节点误差
        public double[][][] layer_weight;//各层节点权重
        public double[][][] layer_weight_delta;//各层节点权重动量
        public double mobp;//动量系数
        public double rate;//学习系数

        public BPnet(int[] layernum, double rate, double mobp){
            this.mobp = mobp;
            this.rate = rate;
            layer = new double[layernum.length][];
            layerErr = new double[layernum.length][];
            layer_weight = new double[layernum.length][][];
            layer_weight_delta = new double[layernum.length][][];
            Random random = new Random();
            for(int l=0;l<layernum.length;l++){
                layer[l]=new double[layernum[l]];
                layerErr[l]=new double[layernum[l]];
                if(l+1<layernum.length){
                    layer_weight[l]=new double[layernum[l]+1][layernum[l+1]];
                    layer_weight_delta[l]=new double[layernum[l]+1][layernum[l+1]];
                    for(int j=0;j<layernum[l]+1;j++)
                        for(int i=0;i<layernum[l+1];i++)
                            layer_weight[l][j][i]=random.nextDouble();//随机初始化权重
                }   
            }
        }
        //逐层向前计算输出
        public double[] computeOut(double[] in){
            for(int l=1;l<layer.length;l++){
                for(int j=0;j<layer[l].length;j++){
                    double z=layer_weight[l-1][layer[l-1].length][j];
                    for(int i=0;i<layer[l-1].length;i++){
                        layer[l-1][i]=l==1?in[i]:layer[l-1][i];
                        z+=layer_weight[l-1][i][j]*layer[l-1][i];
                    }
                //    System.out.println(z+"####");
                    
                    layer[l][j]=1/(1+Math.exp(-0.01*z));
                //    System.out.println("&&**"+layer[l][j]);
                    
                    
                }
            }
          //System.out.println("&&^^**"+layer[layer.length-1][0]);
            return layer[layer.length-1];
        }
        //逐层反向计算误差并修改权重
        public void updateWeight(double[] tar){
            int l=layer.length-1;
            for(int j=0;j<layerErr[l].length;j++)
                layerErr[l][j]=layer[l][j]*(1-layer[l][j])*(1/(1+Math.exp(-0.01*tar[j]))-layer[l][j]);

            while(l-->0){
                for(int j=0;j<layerErr[l].length;j++){
                    double z = 0.0;
                    for(int i=0;i<layerErr[l+1].length;i++){
                        z=z+l>0?layerErr[l+1][i]*layer_weight[l][j][i]:0;
                        layer_weight_delta[l][j][i]= mobp*layer_weight_delta[l][j][i]+rate*layerErr[l+1][i]*layer[l][j];//隐含层动量调整
                        layer_weight[l][j][i]+=layer_weight_delta[l][j][i];//隐含层权重调整
                        if(j==layerErr[l].length-1){
                            layer_weight_delta[l][j+1][i]= mobp*layer_weight_delta[l][j+1][i]+rate*layerErr[l+1][i];//截距动量调整
                            layer_weight[l][j+1][i]+=layer_weight_delta[l][j+1][i];//截距权重调整
                        }
                    }
                    layerErr[l][j]=z*layer[l][j]*(1-layer[l][j]);//记录误差
                }
            }
        }

        public void train(double[] in, double[] tar){
            double[] out = computeOut(in);
            updateWeight(tar);
        }
    }

     

    从图中可以看出2012年初,股市变化幅度很大时,模型拟合效果稍差,但总体拟合效果较好。(红线表示拟合曲线,蓝线表示真实收盘价)

    测试数据集采用的是Data2015to2016,即2015年至2016年数据,拟合拟合效果如下:

     

    从图中可以看出曲线可以拟合大致趋势,但是不能很好的拟合波动,可能是由于对训练数据集过渡拟合的原因。

    BackProce:该类计算了如果按照神经网络模型对该股票进行操作的结果,采用的策略是,如果下一天的预测值高于当天的收盘价,就买入,低于就卖出,设置初始账户金额为10000.

    可得到最后的收益率为0.18364521221914928,账户金额为:11836.452122191493。

    累计收益率如下图:

     

    累计收益率明显呈现上升趋势。

     

    package it.cast;

    import java.util.ArrayList;
    import java.util.List;

    public class BackProce {
        
        public List<ArrayList<Double>> selectChance(List<ArrayList<Double>>result,double account){
            double accountF = account;
            System.out.println("初始账户为: "+account);
            ArrayList<Double>expect = new ArrayList<Double>();
            ArrayList<Double>target = new ArrayList<Double>();
            for(int i = 0;i<result.size();i++){
                expect.add(result.get(i).get(0));
                target.add(result.get(i).get(1));
            }
            List<ArrayList<Double>>chance = new ArrayList<ArrayList<Double>>();
            for(int i = 1;i<expect.size();i++){
                if(expect.get(i)>target.get(i-1)){
                    //买入
                    account += account*(target.get(i)-target.get(i-1))/target.get(i-1);
                }
                ArrayList<Double>list = new ArrayList<Double>();
                list.add((account-accountF)/accountF);
                list.add((double) i);
                chance.add(list);
            }
            System.out.println("期末账户为: "+account);
            System.out.println("年化收益率为: "+(account-accountF)/accountF);
            return chance;
        }
    }

    辅助类Graph:该类借助了jfree包,用于绘制图像

    package it.cast;

    import java.awt.BasicStroke;
    import java.awt.Color;
    import java.awt.Font;
    import java.io.FileOutputStream;
    import java.io.OutputStream;
    import java.util.ArrayList;
    import java.util.Date;
    import java.util.List;

    import javax.swing.JPanel;

    import org.jfree.chart.ChartFactory;
    import org.jfree.chart.ChartPanel;
    import org.jfree.chart.ChartUtilities;
    import org.jfree.chart.JFreeChart;
    import org.jfree.chart.axis.NumberAxis;
    import org.jfree.chart.plot.CategoryPlot;
    import org.jfree.chart.plot.PlotOrientation;
    import org.jfree.chart.renderer.category.LineAndShapeRenderer;
    import org.jfree.chart.title.TextTitle;
    import org.jfree.data.category.CategoryDataset;
    import org.jfree.data.category.DefaultCategoryDataset;
    import org.jfree.ui.ApplicationFrame;
    import org.jfree.ui.HorizontalAlignment;
    import org.jfree.ui.RectangleEdge;

    public class Graph extends ApplicationFrame{
        ChartPanel frame1;  
        private static final long serialVersionUID = 1L;
        
        public Graph(String s , List<ArrayList<Double>> excel) {
           super(s);
           setContentPane(createDemoLine(excel));
        }
        
        public static DefaultCategoryDataset createDataset(List<ArrayList<Double>> excel) {
            DefaultCategoryDataset linedataset = new DefaultCategoryDataset();
            for (int i=0; i <excel.size(); i++) {
                linedataset.addValue(excel.get(i).get(0), "expect", excel.get(i).get(1));
                //linedataset.addValue(excel.get(i).get(1), "target", Integer.toString(i+1));
            }
     
            return linedataset;
         }
        
        public static JPanel createDemoLine(List<ArrayList<Double>> excel) {
            JFreeChart jfreechart = createChart(createDataset(excel));
            return new ChartPanel(jfreechart);
         }
        
     // 生成图表主对象JFreeChart
        public static JFreeChart createChart(DefaultCategoryDataset linedataset) {
           // 定义图表对象
           JFreeChart chart = ChartFactory.createLineChart("Cumulative rate of return", //折线图名称
             "time", // 横坐标名称
             "Value", // 纵坐标名称
             linedataset, // 数据
             PlotOrientation.VERTICAL, // 水平显示图像
             true, // include legend
             false, // tooltips
             false // urls
             );
            // chart.setBackgroundPaint(Color.red);
             
           CategoryPlot plot = chart.getCategoryPlot();
          // plot.setDomainGridlinePaint(Color.red);
           plot.setDomainGridlinesVisible(true);
           // 5,设置水平网格线颜色
          // plot.setRangeGridlinePaint(Color.blue);
           // 6,设置是否显示水平网格线
           plot.setRangeGridlinesVisible(true);
           plot.setRangeGridlinesVisible(true); //是否显示格子线
           //plot.setBackgroundAlpha(f); //设置背景透明度
           
           NumberAxis rangeAxis = (NumberAxis)plot.getRangeAxis();
            
           rangeAxis.setStandardTickUnits(NumberAxis.createIntegerTickUnits());
           rangeAxis.setAutoRangeIncludesZero(true);
           rangeAxis.setUpperMargin(0.20);
           rangeAxis.setLabelAngle(Math.PI / 2.0);
           rangeAxis.setAutoRange(false);
           FileOutputStream fos_jpg=null;
           try{
            fos_jpg=new FileOutputStream("D:\ok_bing.jpg");
            /*
             * 第二个参数如果为100,会报异常:
             * java.lang.IllegalArgumentException: The 'quality' must be in the range 0.0f to 1.0f
             * 限制quality必须小于等于1,把100改成 0.1f。
             */
           // ChartUtilities.writeChartAsJPEG(fos_jpg, 0.99f, chart, 600, 300, null);
            ChartUtilities.writeChartAsJPEG(fos_jpg, chart, 900, 400);
             
           }catch(Exception e){
            System.out.println("[e]"+e);
           }finally{
            try{
             fos_jpg.close();
            }catch(Exception e){
              
            }
           }
           return chart;
        }
    }

    主函数类testClass

    package it.cast;

    import java.io.IOException;
    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.ResultSet;
    import java.sql.Statement;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class testClass {

        public static void main(String[] args) {
            dataBase data = new dataBase();
            //        data.createDataBase();

            try{
                Connection conn = null;
                Statement stmt = null;
                //链接数据库
                Class.forName("oracle.jdbc.driver.OracleDriver");
                String url = "jdbc:oracle:thin:@localhost:1521:ORCL";
                String UserName = "system";    
                String password = "manager";
                conn = DriverManager.getConnection(url, UserName, password);
                stmt = conn.createStatement();

                String sql2="select * from Data2012to2015";
                ResultSet rs = stmt.executeQuery(sql2);
                //创建序列
                List<Double> openPrice = new ArrayList<Double>();
                List<Double> highPrice = new ArrayList<Double>();
                List<Double> overPrice = new ArrayList<Double>();
                List<Double> lowPrice = new ArrayList<Double>();
                List<Double> vol = new ArrayList<Double>();

                while (rs.next()){
                    openPrice.add(Double.parseDouble(rs.getString("OPENPRICE")));
                    highPrice.add(Double.parseDouble(rs.getString("HIGHPRICE")));
                    overPrice.add(Double.parseDouble(rs.getString("OVERPRICE")));
                    lowPrice.add(Double.parseDouble(rs.getString("LOWPRICE")));
                    vol.add(Double.parseDouble(rs.getString("VOL")));

                }

                Methods m = new Methods();
                double [][]dataset = m.bpTrain(overPrice, highPrice, lowPrice, openPrice, vol);
                double [][]target = new double[dataset.length][];
                for(int i = 0;i<dataset.length;i++){
                    target[i] = new double[1];
                    target[i][0] = overPrice.get(overPrice.size()-dataset.length+i);
                }
                
                
                
                String sql3="select * from Data2015to2016";
                ResultSet rs2 = stmt.executeQuery(sql3);
                //创建序列
                List<Double> openPrice2 = new ArrayList<Double>();
                List<Double> highPrice2 = new ArrayList<Double>();
                List<Double> overPrice2 = new ArrayList<Double>();
                List<Double> lowPrice2 = new ArrayList<Double>();
                List<Double> vol2 = new ArrayList<Double>();

                while (rs2.next()){
                    openPrice2.add(Double.parseDouble(rs.getString("OPENPRICE")));
                    highPrice2.add(Double.parseDouble(rs.getString("HIGHPRICE")));
                    overPrice2.add(Double.parseDouble(rs.getString("OVERPRICE")));
                    lowPrice2.add(Double.parseDouble(rs.getString("LOWPRICE")));
                    vol2.add(Double.parseDouble(rs.getString("VOL")));

                }

                Methods m2 = new Methods();
                double [][]dataset2 = m2.bpTrain(overPrice2, highPrice2, lowPrice2, openPrice2, vol2);
                double [][]target2 = new double[dataset2.length][];
                for(int i = 0;i<dataset2.length;i++){
                    target2[i] = new double[1];
                    target2[i][0] = overPrice2.get(overPrice2.size()-dataset2.length+i);
                }



                BPnet bp = new BPnet(new int[]{6,13,13,1}, 0.15, 0.8);
                //迭代训练5000次
                for(int n=0;n<50000;n++)
                    for(int i=0;i<dataset.length;i++)
                        bp.train(dataset[i], target[i]);


                //测试数据集
                double []result = new double[dataset2.length];
                List<ArrayList<Double>>resultList = new ArrayList<ArrayList<Double>>();
                for(int j=0;j<dataset2.length;j++){
                    double []a = bp.computeOut(dataset2[j]);
                    ArrayList<Double>list = new ArrayList<Double>();
                    result[j] = 100*(-Math.log(1/a[0]-1));
                    list.add(result[j]);
                    list.add(target2[j][0]);
                    resultList.add(list);
                    System.out.println(Arrays.toString(dataset2[j])+":"+result[j]+" real:"+target2[j][0]);
                }
                //new Graph("1",resultList);
                
                BackProce b = new BackProce();
                double account = 10000;
                List<ArrayList<Double>>chance = b.selectChance(resultList,account);
                new Graph("1",chance);
                
                
                
                


            }catch (Exception e) {
                e.printStackTrace();
                // TODO: handle exception
            }
            System.out.println("End");
        }


    }

    缺点:1、只能绘制基本图像,没有找到方法将特殊点标出,如:能够获取在什么时间点买入,但是不知怎么在特定点用其他颜色标出。

    2、神经网络模型对训练数据拟合很好,但是对测试数据拟合效果不佳,猜测原因可能是过拟合或是有些其他主要的变量因素没有考虑进去。

  • 相关阅读:
    9月9
    JavaScript语法(三)
    JavaScript语法(二)
    实现AJAX的基本步骤 。。转
    Ajax 完整教程。。转载
    Struts2中的Action类(解耦方式,耦合方式)
    web应用中使用JavaMail发送邮件 。。转载
    Struts2下的<result>中的type整理
    Struts2整理+课堂代码+注意事项
    一对多,多对一,注意事项总结
  • 原文地址:https://www.cnblogs.com/yunerlalala/p/6187770.html
Copyright © 2011-2022 走看看