首页 > 代码库 > 利用神经网络预测股票收盘价(含源代码)

利用神经网络预测股票收盘价(含源代码)

攒了几天,发一个大的

这是前几天投了一家量化分析职位,他给的题目的是写神经网络择时模型,大概就是用神经网络预测收盘价

database:该类用于获得新浪网中的数据,并将其放入本地数据库。在本地数据库中建立两个表,分别是Data2012to2015和Data2015to2016,表中都含有日期,当日开盘价、当日收盘价、当日最高价、当日最低价。Data2012to2015为训练数据集,Data2015to2016为测试数据集。

package it.cast;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;

public class dataBase {
    //创建训练集:Data2012to2015和测试集Data2015to2016
    public  void createDataBase() {
    try {
        Connection conn = null;
        Statement stmt = null;
        //链接数据库
        Class.forName("oracle.jdbc.driver.OracleDriver");
        String url = "jdbc:oracle:thin:@localhost:1521:ORCL";
        String UserName = "system";    
        String password = "manager";
        conn = DriverManager.getConnection(url, UserName, password);
        stmt = conn.createStatement();
        historyShare( conn, stmt);
        
    }catch (Exception e) {

            e.printStackTrace();

        }
    }

    private void historyShare( Connection conn, Statement stmt)
            throws SQLException, MalformedURLException, IOException,
            UnsupportedEncodingException {
        //创建表格
        //表格列为:股票id号、日期、开盘价、最高价、收盘价、最低价、成交量
        String sql = "create table Data2015to2016(stokeid integer not null primary key ," +
                "data varchar2(20), openPrice varchar2(20), highPrice varchar2(20), overPrice varchar2(20),lowPrice varchar2(20)," +
                "vol varchar2(20))";

        stmt.executeUpdate(sql);


        URL ur = null;
        ur = new URL("http://biz.finance.sina.com.cn/stock/flash_hq/kline_data.php?&rand=random(10000)&symbol=sz000001&end_date=20161118&begin_date=20151118&type=plain");

        HttpURLConnection uc = (HttpURLConnection) ur.openConnection();

        BufferedReader reader = new BufferedReader(new InputStreamReader(ur.openStream(),"GBK"));
        String line;
        PreparedStatement stmt1 = null;
        int i=1;
        //插入数据
        while((line = reader.readLine()) != null){
            //普通股票
            String sql1 = "insert into Data2015to2016 values(?,?, ?, ?, ?, ?, ?)";
            stmt1 = conn.prepareStatement(sql1);
            String[] data=http://www.mamicode.com/line.split(",");
            String date = data[0];
            String openPrice = data[1];
            String highPrice = data[2];
            String overPrice = data[3];
            String lowPrice = data[4];
            stmt1.setInt(1, i++);
            stmt1.setString(2, data[0]);
            stmt1.setString(3, data[1]);
            stmt1.setString(4, data[2]);
            stmt1.setString(5, data[3]);
            stmt1.setString(6, data[4]);
            stmt1.setString(7, data[5]);
            stmt1.executeUpdate();
            stmt1.close();
        }
    }

}

Methods:由于java没有现成的包可以直接得出某只股票的波动率指标、短期和长期均线指标等指标,由于一些指标在网上没有找到,例如动量和反转指标:REVS5,就用了动量指标MTM。所以在百度百科等资料中搜集了一些公式, 分别对这些公式编写代码,就能观测到的数据来说,是准确的。

最后采用了8个指标,分别是波动率指标:EMV;短期和长期均线指标:EMA5和EMA60,MA5和MA60;动量指标MTM;量能指标:MACD;能量指标:CR5.以这8个指标为自变量,收盘价为因变量建立神经网络模型。

package it.cast;

import java.util.ArrayList;
import java.util.List;

public class Methods {
    
    

    //搜狗百科:A=(今日最高+今日最低)/2;B=(前日最高+前日最低)/2;C=今日最高-今日最低;2.EM=(A-B)*C/今日成交额;3.EMV=N日内EM的累和;4.MAEMV=EMV的M日简单移动平均.参数N为14,参数M为9
    public List<Double> EMV(List<Double>highPrice,List<Double>lowPrice,List<Double>vol){
        List<Double>EM = new ArrayList<Double>();
        for(int i = 2;i<highPrice.size();i++){
            double A = (highPrice.get(i)+lowPrice.get(i))/2;
            double B = (highPrice.get(i-2)+lowPrice.get(i-2))/2;
            double C = highPrice.get(i)-lowPrice.get(i);
            EM.add(((A-B)*C)/vol.get(i));
        }

        List<Double>EMV = new ArrayList<Double>();
        //取N为14,即14日的EM值之和;M为9,即9日的移动平均
        int N = 14;
        int M = 9;
        for(int i = N;i<EM.size()+1;i++){
            //14日累和
            double sum = 0;
            for(int j = i-N;j<i;j++){
                sum += EM.get(j);
            }
            EMV.add(sum);
        }

        List<Double>MAEMV = new ArrayList<Double>();
        for(int i = M;i<EMV.size()+1;i++){
            //9日移动平均
            double sum = 0;
            for(int j = i-M;j<i;j++){
                sum += EMV.get(j);
            }
            sum = sum/M;
            MAEMV.add(sum);
        }
        return MAEMV;
    }

    //EMA=(当日或当期收盘价-上一日或上期EXPMA)/N+上一日或上期EXPMA,其中,首次上期EXPMA值为上一期收盘价,N为天数。
    public List<Double> EMA5(List<Double>overPrice){
        //取20121118年收盘价为初始EXPMA
        List<Double>EMA5 = new ArrayList<Double>();
        for(int i = 0;i<5;i++){
            EMA5.add(overPrice.get(i));
        }
        for(int i = 5;i<overPrice.size();i++){
            EMA5.add((overPrice.get(i)-EMA5.get(i-5))/5+EMA5.get(i-5));

        }
        return EMA5;
    }


    public List<Double> EMA60(List<Double>overPrice){
        //取20121118年收盘价为初始EXPMA
        List<Double>EMA60 = new ArrayList<Double>();
        for(int i = 0;i<60;i++){
            EMA60.add(overPrice.get(i));
        }
        for(int i = 60;i<overPrice.size();i++){
            EMA60.add((overPrice.get(i)-EMA60.get(i-60))/60+EMA60.get(i-60));
        }
        return EMA60;
    }

    //5日均线
    public List<Double> MA5(List<Double>overPrice){
        List<Double>MA5 = new ArrayList<Double>();
        for(int i = 5;i<overPrice.size()+1;i++){
            double sum = 0;
            for(int j = i-1;j>=i-5;j--){
                sum += overPrice.get(j);
            }
            sum = sum/5;
            MA5.add(sum);
        }
        return MA5;
    }


    //60日均线
    public List<Double> MA60(List<Double>overPrice){
        List<Double>MA60 = new ArrayList<Double>();
        for(int i = 60;i<overPrice.size()+1;i++){
            double sum = 0;
            for(int j = i-1;j>=i-60;j--){
                sum += overPrice.get(j);
            }
            sum = sum/60;
            MA60.add(sum);
        }
        return MA60;
    }

    //动量指标MTM,1.MTM=当日收盘价-N日前收盘价;2.MTMMA=MTM的M日移动平均;3.参数N一般设置为12日参数M一般设置为6,表中当动量值减低或反转增加时,应为买进或卖出时机
    public List<Double> MTM(List<Double>overPrice){
        List<Double>MTM = new ArrayList<Double>();
        List<Double>MTMlist = new ArrayList<Double>();
        int N = 12;
        int M = 6;
        for(int i = 12;i<overPrice.size();i++){
            MTM.add(overPrice.get(i)-overPrice.get(i-12));
        }
        
        //移动平均参数为6
        for(int i = 6;i<MTM.size()+1;i++){
            double sum = 0;
            for(int j = i-1;j>=i-6;j--){
                sum += MTM.get(j);
            }
            sum = sum/6;
            MTMlist.add(sum);
        }
        return MTMlist;
    }
    
    
    //百度百科:http://baike.baidu.com/link?url=XQf2I-JIyNR1AEM_EnMnuU90U1vmJDoXukUe1fQVsBA1Y_fqAA8dj7DoxLCoh5U-YysBkVT5aIZLXeG2g1snoK:量能指标就是通过动态分析成交量的变化,
    public List<Double> MACD(List<Double>vol){
        int shortN = 12;
        List<Double>Short = new ArrayList<Double>();
        for(int i = shortN;i<vol.size()+1;i++){
            Short.add(2*vol.get(i-1)+(shortN-1)*vol.get(i-shortN));
        }
        int longN = 26;
        List<Double>Long = new ArrayList<Double>();
        for(int i = longN;i<vol.size()+1;i++){
            Long.add(2*vol.get(i-1)+(longN-1)*vol.get(i-longN));
        }
        
        //    取两个序列中较短序列的长度
        int length = 0;
        if(Short.size()>Long.size()){
            length = Long.size();
        }else{
            length = Short.size();
        }
        
        List<Double>DIFF1 = new ArrayList<Double>();
        for(int i = length-1;i>=0;i--){
            DIFF1.add(Short.get(i)-Long.get(i));
        }
        List<Double>DIFF = new ArrayList<Double>();
        for(int i = 0;i<DIFF1.size();i++){
            DIFF.add(DIFF1.get(DIFF1.size()-i-1));
        }
        List<Double>DEA = new ArrayList<Double>();
        for(int i = 0;i<DIFF.size()-1;i++){
            DEA.add(2*DIFF.get(i+1)+(9-1)*DIFF.get(i));
        }
        
        List<Double>MACD = new ArrayList<Double>();
        for(int i = 1;i<DIFF.size();i++){
            MACD.add(DIFF.get(i)-DEA.get(i-1));
        }
        return MACD;
    }
    
    
    //能量指标:CR,见百度百科:http://baike.baidu.com/link?url=v5yYFep6wZioav0P-LOruuhkzjho6PqzQqfEBj5TYQLfaadLSADSQVl0njP7k1zY78KJMoBFrE4OO4wYolZXbMnRRQi7U66R0X2jeSV3ZoXKeuG2zEbqEqP4CnyiF7j6
    public List<Double> CR5(List<Double>overPrice,List<Double>highPrice,List<Double>lowPrice,List<Double>openPrice){
        List<Double> YM = new ArrayList<Double>();
        List<Double> HYM = new ArrayList<Double>();
        List<Double> YML = new ArrayList<Double>();
        List<Double> CR = new ArrayList<Double>();
        for(int i = 0;i<overPrice.size();i++){
            YM.add((highPrice.get(i)+overPrice.get(i)+lowPrice.get(i)+openPrice.get(i))/4);
        }
        //p1表示5日以来多方力量总和,p2表示5日以来空方力量总和
        for(int i = 6;i<highPrice.size()+1;i++){
            double sum = 0;
            for(int j = i-1;j>=i-5;j--){
                sum += highPrice.get(j)-YM.get(j-1);
            }
            HYM.add(sum);
        }
        //p2表示5日以来空方力量总和,p2表示5日以来空方力量总和
        for(int i = 6;i<lowPrice.size()+1;i++){
            double sum = 0;
            for(int j = i-1;j>=i-5;j--){
                sum += YM.get(j-1)-lowPrice.get(j);
            }
            YML.add(sum);
        }
        for(int i = 0;i<YML.size();i++){
            double temp = (double)HYM.get(i)/YML.get(i);
            if(temp<0){
                CR.add((double) 0);
            }else{
                CR.add(temp);
            }
            
        }
        return CR;
                

    }
    
    public double[][] bpTrain(List<Double>overPrice,List<Double>highPrice,List<Double>lowPrice,List<Double>openPrice,List<Double>vol){
        List<Double>EMV = EMV(highPrice, lowPrice, vol);
        List<Double>EMA5 = EMA5(overPrice);
        List<Double>EMA60 = EMA60(overPrice);
        List<Double>MA5 = MA5(overPrice);
        List<Double>MA60 = MA60(overPrice);
        List<Double>MTM = MTM(overPrice);
        List<Double>MACD = MACD(vol);
        List<Double>CR5 = CR5(overPrice, highPrice, lowPrice, openPrice);
        
        int length = 0;
        if(EMA60.size()>MA60.size()){
            length = MA60.size();
        }else{
            length = EMA60.size();
        }
        List<ArrayList<Double>>datalist = new ArrayList<ArrayList<Double>>();
        for(int i = 0;i<length;i++){
            ArrayList<Double>list = new ArrayList<Double>();
            //list.add(EMV.get(EMV.size()-length+i));
            list.add(EMA5.get(EMA5.size()-length+i));
            list.add(EMA60.get(EMA60.size()-length+i));
            list.add(MA5.get(MA5.size()-length+i));
            list.add(MA60.get(MA60.size()-length+i));
            list.add(MTM.get(MTM.size()-length+i));
    //        list.add(MACD.get(MACD.size()-length+i));
            list.add(CR5.get(CR5.size()-length+i));
            datalist.add(list);
        }
        double [][]data = http://www.mamicode.com/new double[datalist.size()][6];
        for(int i = 0;i<datalist.size();i++){
            for(int j = 0;j<6;j++){
                data[i][j] = datalist.get(i).get(j);
                System.out.print(data[i][j]+"  ");
            }
            System.out.println();
        }
        return data;
    }
    
}

 

 

BPnet:这里想建立输入单元为8个,两层隐含层,每个隐含层为13个单元,输出层单元为1的神经网络。

首先初始化输入层到隐含层,隐含层之间,以及隐含层到输出层的权重矩阵;

其次利用权重矩阵和输入层分别计算出每个隐含层节点数据

之后利用计算得出的输出层数据与真实值进行比较,并逐层调节权重;

反复上述过程直至精度达到要求或是达到迭代次数的要求;

这里设置迭代次数为5000次;

利用的测试数据集为Data2012to2015

下图为训练之后的模型对Data2012to2015自身进行拟合的效果:(这里由于自变量大概是10左右的数据,所以在利用激活函数1/(1+e^-ax))时,a取了0.01

package it.cast;

import java.util.Random;

public class BPnet {
    public double[][] layer;//神经网络各层节点
    public double[][] layerErr;//神经网络各节点误差
    public double[][][] layer_weight;//各层节点权重
    public double[][][] layer_weight_delta;//各层节点权重动量
    public double mobp;//动量系数
    public double rate;//学习系数

    public BPnet(int[] layernum, double rate, double mobp){
        this.mobp = mobp;
        this.rate = rate;
        layer = new double[layernum.length][];
        layerErr = new double[layernum.length][];
        layer_weight = new double[layernum.length][][];
        layer_weight_delta = new double[layernum.length][][];
        Random random = new Random();
        for(int l=0;l<layernum.length;l++){
            layer[l]=new double[layernum[l]];
            layerErr[l]=new double[layernum[l]];
            if(l+1<layernum.length){
                layer_weight[l]=new double[layernum[l]+1][layernum[l+1]];
                layer_weight_delta[l]=new double[layernum[l]+1][layernum[l+1]];
                for(int j=0;j<layernum[l]+1;j++)
                    for(int i=0;i<layernum[l+1];i++)
                        layer_weight[l][j][i]=random.nextDouble();//随机初始化权重
            }   
        }
    }
    //逐层向前计算输出
    public double[] computeOut(double[] in){
        for(int l=1;l<layer.length;l++){
            for(int j=0;j<layer[l].length;j++){
                double z=layer_weight[l-1][layer[l-1].length][j];
                for(int i=0;i<layer[l-1].length;i++){
                    layer[l-1][i]=l==1?in[i]:layer[l-1][i];
                    z+=layer_weight[l-1][i][j]*layer[l-1][i];
                }
            //    System.out.println(z+"####");
                
                layer[l][j]=1/(1+Math.exp(-0.01*z));
            //    System.out.println("&&**"+layer[l][j]);
                
                
            }
        }
      //System.out.println("&&^^**"+layer[layer.length-1][0]);
        return layer[layer.length-1];
    }
    //逐层反向计算误差并修改权重
    public void updateWeight(double[] tar){
        int l=layer.length-1;
        for(int j=0;j<layerErr[l].length;j++)
            layerErr[l][j]=layer[l][j]*(1-layer[l][j])*(1/(1+Math.exp(-0.01*tar[j]))-layer[l][j]);

        while(l-->0){
            for(int j=0;j<layerErr[l].length;j++){
                double z = 0.0;
                for(int i=0;i<layerErr[l+1].length;i++){
                    z=z+l>0?layerErr[l+1][i]*layer_weight[l][j][i]:0;
                    layer_weight_delta[l][j][i]= mobp*layer_weight_delta[l][j][i]+rate*layerErr[l+1][i]*layer[l][j];//隐含层动量调整
                    layer_weight[l][j][i]+=layer_weight_delta[l][j][i];//隐含层权重调整
                    if(j==layerErr[l].length-1){
                        layer_weight_delta[l][j+1][i]= mobp*layer_weight_delta[l][j+1][i]+rate*layerErr[l+1][i];//截距动量调整
                        layer_weight[l][j+1][i]+=layer_weight_delta[l][j+1][i];//截距权重调整
                    }
                }
                layerErr[l][j]=z*layer[l][j]*(1-layer[l][j]);//记录误差
            }
        }
    }

    public void train(double[] in, double[] tar){
        double[] out = computeOut(in);
        updateWeight(tar);
    }
}

 

 

从图中可以看出2012年初,股市变化幅度很大时,模型拟合效果稍差,但总体拟合效果较好。(红线表示拟合曲线,蓝线表示真实收盘价)

测试数据集采用的是Data2015to2016,即2015年至2016年数据,拟合拟合效果如下:

 技术分享

从图中可以看出曲线可以拟合大致趋势,但是不能很好的拟合波动,可能是由于对训练数据集过渡拟合的原因。

 

BackProce:该类计算了如果按照神经网络模型对该股票进行操作的结果,采用的策略是,如果下一天的预测值高于当天的收盘价,就买入,低于就卖出,设置初始账户金额为10000.

可得到最后的收益率为0.18364521221914928,账户金额为:11836.452122191493。

累计收益率如下图:

 技术分享

 

累计收益率明显呈现上升趋势。

 

package it.cast;

import java.util.ArrayList;
import java.util.List;

public class BackProce {
    
    public List<ArrayList<Double>> selectChance(List<ArrayList<Double>>result,double account){
        double accountF = account;
        System.out.println("初始账户为: "+account);
        ArrayList<Double>expect = new ArrayList<Double>();
        ArrayList<Double>target = new ArrayList<Double>();
        for(int i = 0;i<result.size();i++){
            expect.add(result.get(i).get(0));
            target.add(result.get(i).get(1));
        }
        List<ArrayList<Double>>chance = new ArrayList<ArrayList<Double>>();
        for(int i = 1;i<expect.size();i++){
            if(expect.get(i)>target.get(i-1)){
                //买入
                account += account*(target.get(i)-target.get(i-1))/target.get(i-1);
            }
            ArrayList<Double>list = new ArrayList<Double>();
            list.add((account-accountF)/accountF);
            list.add((double) i);
            chance.add(list);
        }
        System.out.println("期末账户为: "+account);
        System.out.println("年化收益率为: "+(account-accountF)/accountF);
        return chance;
    }
}

辅助类Graph:该类借助了jfree包,用于绘制图像

package it.cast;

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Font;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;

import javax.swing.JPanel;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.category.LineAndShapeRenderer;
import org.jfree.chart.title.TextTitle;
import org.jfree.data.category.CategoryDataset;
import org.jfree.data.category.DefaultCategoryDataset;
import org.jfree.ui.ApplicationFrame;
import org.jfree.ui.HorizontalAlignment;
import org.jfree.ui.RectangleEdge;

public class Graph extends ApplicationFrame{
    ChartPanel frame1;  
    private static final long serialVersionUID = 1L;
    
    public Graph(String s , List<ArrayList<Double>> excel) {
       super(s);
       setContentPane(createDemoLine(excel));
    }
    
    public static DefaultCategoryDataset createDataset(List<ArrayList<Double>> excel) {
        DefaultCategoryDataset linedataset = new DefaultCategoryDataset();
        for (int i=0; i <excel.size(); i++) {
            linedataset.addValue(excel.get(i).get(0), "expect", excel.get(i).get(1));
            //linedataset.addValue(excel.get(i).get(1), "target", Integer.toString(i+1));
        }
 
        return linedataset;
     }
    
    public static JPanel createDemoLine(List<ArrayList<Double>> excel) {
        JFreeChart jfreechart = createChart(createDataset(excel));
        return new ChartPanel(jfreechart);
     }
    
 // 生成图表主对象JFreeChart
    public static JFreeChart createChart(DefaultCategoryDataset linedataset) {
       // 定义图表对象
       JFreeChart chart = ChartFactory.createLineChart("Cumulative rate of return", //折线图名称
         "time", // 横坐标名称
         "Value", // 纵坐标名称
         linedataset, // 数据
         PlotOrientation.VERTICAL, // 水平显示图像
         true, // include legend
         false, // tooltips
         false // urls
         );
        // chart.setBackgroundPaint(Color.red);
         
       CategoryPlot plot = chart.getCategoryPlot();
      // plot.setDomainGridlinePaint(Color.red);
       plot.setDomainGridlinesVisible(true);
       // 5,设置水平网格线颜色
      // plot.setRangeGridlinePaint(Color.blue);
       // 6,设置是否显示水平网格线
       plot.setRangeGridlinesVisible(true);
       plot.setRangeGridlinesVisible(true); //是否显示格子线
       //plot.setBackgroundAlpha(f); //设置背景透明度
       
       NumberAxis rangeAxis = (NumberAxis)plot.getRangeAxis();
        
       rangeAxis.setStandardTickUnits(NumberAxis.createIntegerTickUnits());
       rangeAxis.setAutoRangeIncludesZero(true);
       rangeAxis.setUpperMargin(0.20);
       rangeAxis.setLabelAngle(Math.PI / 2.0);
       rangeAxis.setAutoRange(false);
       FileOutputStream fos_jpg=null;
       try{
        fos_jpg=new FileOutputStream("D:\\ok_bing.jpg");
        /*
         * 第二个参数如果为100,会报异常:
         * java.lang.IllegalArgumentException: The ‘quality‘ must be in the range 0.0f to 1.0f
         * 限制quality必须小于等于1,把100改成 0.1f。
         */
       // ChartUtilities.writeChartAsJPEG(fos_jpg, 0.99f, chart, 600, 300, null);
        ChartUtilities.writeChartAsJPEG(fos_jpg, chart, 900, 400);
         
       }catch(Exception e){
        System.out.println("[e]"+e);
       }finally{
        try{
         fos_jpg.close();
        }catch(Exception e){
          
        }
       }
       return chart;
    }
}

主函数类testClass

package it.cast;

import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class testClass {

    public static void main(String[] args) {
        dataBase data = http://www.mamicode.com/new dataBase();
        //        data.createDataBase();

        try{
            Connection conn = null;
            Statement stmt = null;
            //链接数据库
            Class.forName("oracle.jdbc.driver.OracleDriver");
            String url = "jdbc:oracle:thin:@localhost:1521:ORCL";
            String UserName = "system";    
            String password = "manager";
            conn = DriverManager.getConnection(url, UserName, password);
            stmt = conn.createStatement();

            String sql2="select * from Data2012to2015";
            ResultSet rs = stmt.executeQuery(sql2);
            //创建序列
            List<Double> openPrice = new ArrayList<Double>();
            List<Double> highPrice = new ArrayList<Double>();
            List<Double> overPrice = new ArrayList<Double>();
            List<Double> lowPrice = new ArrayList<Double>();
            List<Double> vol = new ArrayList<Double>();

            while (rs.next()){
                openPrice.add(Double.parseDouble(rs.getString("OPENPRICE")));
                highPrice.add(Double.parseDouble(rs.getString("HIGHPRICE")));
                overPrice.add(Double.parseDouble(rs.getString("OVERPRICE")));
                lowPrice.add(Double.parseDouble(rs.getString("LOWPRICE")));
                vol.add(Double.parseDouble(rs.getString("VOL")));

            }

            Methods m = new Methods();
            double [][]dataset = m.bpTrain(overPrice, highPrice, lowPrice, openPrice, vol);
            double [][]target = new double[dataset.length][];
            for(int i = 0;i<dataset.length;i++){
                target[i] = new double[1];
                target[i][0] = overPrice.get(overPrice.size()-dataset.length+i);
            }
            
            
            
            String sql3="select * from Data2015to2016";
            ResultSet rs2 = stmt.executeQuery(sql3);
            //创建序列
            List<Double> openPrice2 = new ArrayList<Double>();
            List<Double> highPrice2 = new ArrayList<Double>();
            List<Double> overPrice2 = new ArrayList<Double>();
            List<Double> lowPrice2 = new ArrayList<Double>();
            List<Double> vol2 = new ArrayList<Double>();

            while (rs2.next()){
                openPrice2.add(Double.parseDouble(rs.getString("OPENPRICE")));
                highPrice2.add(Double.parseDouble(rs.getString("HIGHPRICE")));
                overPrice2.add(Double.parseDouble(rs.getString("OVERPRICE")));
                lowPrice2.add(Double.parseDouble(rs.getString("LOWPRICE")));
                vol2.add(Double.parseDouble(rs.getString("VOL")));

            }

            Methods m2 = new Methods();
            double [][]dataset2 = m2.bpTrain(overPrice2, highPrice2, lowPrice2, openPrice2, vol2);
            double [][]target2 = new double[dataset2.length][];
            for(int i = 0;i<dataset2.length;i++){
                target2[i] = new double[1];
                target2[i][0] = overPrice2.get(overPrice2.size()-dataset2.length+i);
            }



            BPnet bp = new BPnet(new int[]{6,13,13,1}, 0.15, 0.8);
            //迭代训练5000次
            for(int n=0;n<50000;n++)
                for(int i=0;i<dataset.length;i++)
                    bp.train(dataset[i], target[i]);


            //测试数据集
            double []result = new double[dataset2.length];
            List<ArrayList<Double>>resultList = new ArrayList<ArrayList<Double>>();
            for(int j=0;j<dataset2.length;j++){
                double []a = bp.computeOut(dataset2[j]);
                ArrayList<Double>list = new ArrayList<Double>();
                result[j] = 100*(-Math.log(1/a[0]-1));
                list.add(result[j]);
                list.add(target2[j][0]);
                resultList.add(list);
                System.out.println(Arrays.toString(dataset2[j])+":"+result[j]+" real:"+target2[j][0]);
            }
            //new Graph("1",resultList);
            
            BackProce b = new BackProce();
            double account = 10000;
            List<ArrayList<Double>>chance = b.selectChance(resultList,account);
            new Graph("1",chance);
            
            
            
            


        }catch (Exception e) {
            e.printStackTrace();
            // TODO: handle exception
        }
        System.out.println("End");
    }


}

 

缺点:1、只能绘制基本图像,没有找到方法将特殊点标出,如:能够获取在什么时间点买入,但是不知怎么在特定点用其他颜色标出。

2、神经网络模型对训练数据拟合很好,但是对测试数据拟合效果不佳,猜测原因可能是过拟合或是有些其他主要的变量因素没有考虑进去。

 

利用神经网络预测股票收盘价(含源代码)