  • Implementing the BP Neural Network Algorithm in Java

    At work I needed to predict how long a process would take, so I decided to use a BP neural network for the prediction.

    Introduction

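    A BP (Back Propagation) neural network is an artificial neural network trained with the back-propagation algorithm, which adjusts the network's weights and thresholds. In the 1980s several researchers independently developed back-propagation for training multilayer perceptrons; the version popularized by David Rumelhart and James McClelland's PDP group was the most influential. The algorithm has two main phases: forward propagation of the working signal, which computes the network's actual output, and backward propagation of the error signal, which corrects the weights and thresholds layer by layer from back to front so that the actual output moves closer to the expected output.

    (1) Forward propagation of the working signal. The input signal enters at the input layer, passes through the connections (synapses) to the hidden-layer neurons, is transformed by the transfer function, and is passed on to the output layer, which computes the output signal. During forward propagation the weights and thresholds stay fixed, and the state of each layer depends only on the net output of the previous layer and on the weights and thresholds. If the output layer produces the expected output, learning ends and the current weights and thresholds are kept; otherwise they are corrected during backward propagation of the error signal.

    (2) Backward propagation of the error signal. If forward propagation does not yield the expected output, the difference between the network's actual output and the expected output is taken as the error signal and propagated layer by layer from the output layer toward the input layer. As the error passes through each layer, that layer's weights and thresholds are adjusted, and this continues until the input layer is reached, so that the network's result moves closer to the expected result.

    If the error still does not meet the requirement after one forward and one backward pass, the process repeats until the error reaches the required precision or another stopping condition, such as a maximum number of iterations, is met.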

    For the derivation, see https://zh.wikipedia.org/wiki/%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E7%AE%97%E6%B3%95
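
    The two phases and the stopping conditions described above can be sketched on a very small network. This is only an illustration under stated assumptions, not the implementation presented later: the class name TinyBackprop, the 1-2-1 layout, the starting weights, and the toy data are all made up for the example.

```java
public class TinyBackprop {
    static double sigmoid(double z) { return 1.0 / (1.0 + Math.exp(-z)); }

    // Hypothetical toy data: one input and one target per sample, already inside (0, 1).
    static final double[] X = {0.2, 0.4, 0.6, 0.8};
    static final double[] T = {0.2, 0.4, 0.6, 0.8};

    // 1-2-1 network: hidden weights/thresholds and output weights/threshold (arbitrary start values).
    double[] w1 = {0.5, -0.5};
    double[] b1 = {0.1, -0.1};
    double[] w2 = {0.3, -0.3};
    double b2 = 0.0;

    /** Forward propagation for one sample; fills h with the hidden-layer outputs. */
    double forward(double x, double[] h) {
        double z = b2;
        for (int j = 0; j < 2; j++) {
            h[j] = sigmoid(w1[j] * x + b1[j]);
            z += w2[j] * h[j];
        }
        return sigmoid(z);
    }

    /** Total error E = 0.5 * sum of (t - y)^2 over all samples. */
    double error() {
        double e = 0.0;
        double[] h = new double[2];
        for (int n = 0; n < X.length; n++) {
            double d = T[n] - forward(X[n], h);
            e += 0.5 * d * d;
        }
        return e;
    }

    /** One batch iteration: forward pass, backward pass, then weight/threshold update. */
    void step(double eta) {
        double[] gw1 = new double[2], gb1 = new double[2], gw2 = new double[2];
        double gb2 = 0.0;
        double[] h = new double[2];
        for (int n = 0; n < X.length; n++) {
            double y = forward(X[n], h);
            // output delta: error times the sigmoid derivative y(1 - y)
            double dOut = (T[n] - y) * y * (1 - y);
            gb2 += dOut;
            for (int j = 0; j < 2; j++) {
                gw2[j] += dOut * h[j];
                // hidden delta: output delta propagated back through w2
                double dHid = dOut * w2[j] * h[j] * (1 - h[j]);
                gw1[j] += dHid * X[n];
                gb1[j] += dHid;
            }
        }
        for (int j = 0; j < 2; j++) {
            w1[j] += eta * gw1[j];
            b1[j] += eta * gb1[j];
            w2[j] += eta * gw2[j];
        }
        b2 += eta * gb2;
    }

    /** Repeats the two phases until the error is small enough or maxTimes is reached. */
    int train(double eta, double precision, int maxTimes) {
        int times = 0;
        while (times < maxTimes && error() > precision) {
            step(eta);
            times++;
        }
        return times;
    }

    public static void main(String[] args) {
        TinyBackprop net = new TinyBackprop();
        double before = net.error();
        int times = net.train(0.5, 1e-6, 5000);
        System.out.println("iterations: " + times + ", error: " + before + " -> " + net.error());
    }
}
```

    The loop stops on whichever condition is met first, which mirrors the precision and iteration-count criteria mentioned above.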

    BPNN structure

    This BPNN has one input layer, one hidden layer, and one output layer.

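    Project structure

    • ActivationFunction: interface for activation functions
    • BPModel: entity class for the BP model
    • BPNeuralNetworkFactory: BP neural network factory; provides training, computation, and serialization
    • BPParameter: entity class for the BP neural network parameters
    • Matrix: matrix entity class
    • Sigmoid: sigmoid transfer function; implements the ActivationFunction interface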

    Implementation

    Matrix entity class

    Implements the basic matrix operations.

    import java.io.Serializable;
    
    public class Matrix implements Serializable {
        private double[][] matrix;
        // number of columns
        private int matrixColNums;
        // number of rows
        private int matrixRowNums;
    
        /**
         * Constructs an empty matrix.
         */
        public Matrix() {
            this.matrix = null;
            this.matrixColNums = 0;
            this.matrixRowNums = 0;
        }
    
        /**
         * Constructs a matrix backed by the given two-dimensional array.
         * @param matrix
         */
        public Matrix(double[][] matrix) {
            this.matrix = matrix;
            this.matrixRowNums = matrix.length;
            this.matrixColNums = matrix[0].length;
        }
    
        /**
         * Constructs a matrix with rowNums rows and colNums columns, filled with zeros.
         * @param rowNums
         * @param colNums
         */
        public Matrix(int rowNums,int colNums) {
            double[][] matrix = new double[rowNums][colNums];
            for (int i = 0; i < rowNums; i++) {
                for (int j = 0; j < colNums; j++) {
                    matrix[i][j] = 0;
                }
            }
            this.matrix = matrix;
            this.matrixRowNums = rowNums;
            this.matrixColNums = colNums;
        }
    
        /**
         * Constructs a matrix with rowNums rows and colNums columns, with every element set to val.
         * @param val
         * @param rowNums
         * @param colNums
         */
        public Matrix(double val,int rowNums,int colNums) {
            double[][] matrix = new double[rowNums][colNums];
            for (int i = 0; i < rowNums; i++) {
                for (int j = 0; j < colNums; j++) {
                    matrix[i][j] = val;
                }
            }
            this.matrix = matrix;
            this.matrixRowNums = rowNums;
            this.matrixColNums = colNums;
        }
    
        public double[][] getMatrix() {
            return matrix;
        }
    
        public void setMatrix(double[][] matrix) {
            this.matrix = matrix;
            this.matrixRowNums = matrix.length;
            this.matrixColNums = matrix[0].length;
        }
    
        public int getMatrixColNums() {
            return matrixColNums;
        }
    
        public int getMatrixRowNums() {
            return matrixRowNums;
        }
    
        /**
         * Returns the element at the given position.
         *
         * @param x row index
         * @param y column index
         * @return the element at row x, column y
         */
        public double getValOfIdx(int x, int y) throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            if (x < 0 || x > matrixRowNums - 1) {
                throw new Exception("Index x out of bounds");
            }
            if (y < 0 || y > matrixColNums - 1) {
                throw new Exception("Index y out of bounds");
            }
            return matrix[x][y];
        }
    
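        /**
         * Returns the given row as a 1 x colNums matrix.
         *
         * @param x row index
         * @return a new matrix holding a copy of row x
         */
        public Matrix getRowOfIdx(int x) throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            if (x < 0 || x > matrixRowNums - 1) {
                throw new Exception("Index x out of bounds");
            }
            double[][] result = new double[1][];
            // copy the row so the returned matrix does not share storage with this one
            result[0] = matrix[x].clone();
            return new Matrix(result);
        }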
    
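        /**
         * Returns the given column as a rowNums x 1 matrix.
         *
         * @param y column index
         * @return a new matrix holding column y
         */
        public Matrix getColOfIdx(int y) throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            if (y < 0 || y > matrixColNums - 1) {
                throw new Exception("Index y out of bounds");
            }
            double[][] result = new double[matrixRowNums][1];
            for (int i = 0; i < matrixRowNums; i++) {
                // the result has a single column, so its only valid column index is 0
                result[i][0] = matrix[i][y];
            }
            return new Matrix(result);
        }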
    
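        /**
         * Matrix product: this * a.
         *
         * @param a right-hand matrix; its row count must equal this matrix's column count
         * @return the product matrix
         * @throws Exception if either matrix is empty or the dimensions are incompatible
         */
        public Matrix multiple(Matrix a) throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            if (a.getMatrix() == null) {
                throw new Exception("Argument matrix is empty");
            }
            if (matrixColNums != a.getMatrixRowNums()) {
                throw new Exception("Matrix dimensions do not match; cannot compute");
            }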
            double[][] result = new double[matrixRowNums][a.getMatrixColNums()];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < a.getMatrixColNums(); j++) {
                    for (int k = 0; k < matrixColNums; k++) {
                        result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j];
                    }
                }
            }
            return new Matrix(result);
        }
    
        /**
         * Multiplies every element by a scalar.
         *
         * @param a the scalar
         * @return the scaled matrix
         */
        public Matrix multiple(double a) throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            double[][] result = new double[matrixRowNums][matrixColNums];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < matrixColNums; j++) {
                    result[i][j] = matrix[i][j] * a;
                }
            }
            return new Matrix(result);
        }
    
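        /**
         * Element-wise (Hadamard) product.
         *
         * @param a matrix with the same dimensions as this one
         * @return the element-wise product
         */
        public Matrix pointMultiple(Matrix a) throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            if (a.getMatrix() == null) {
                throw new Exception("Argument matrix is empty");
            }
            // reject a mismatch in either dimension
            if (matrixRowNums != a.getMatrixRowNums() || matrixColNums != a.getMatrixColNums()) {
                throw new Exception("Matrix dimensions do not match; cannot compute");
            }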
            double[][] result = new double[matrixRowNums][matrixColNums];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < matrixColNums; j++) {
                    result[i][j] = matrix[i][j] * a.getMatrix()[i][j];
                }
            }
            return new Matrix(result);
        }
    
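        /**
         * Matrix addition.
         *
         * @param a matrix with the same dimensions as this one
         * @return the element-wise sum
         */
        public Matrix plus(Matrix a) throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            if (a.getMatrix() == null) {
                throw new Exception("Argument matrix is empty");
            }
            // reject a mismatch in either dimension
            if (matrixRowNums != a.getMatrixRowNums() || matrixColNums != a.getMatrixColNums()) {
                throw new Exception("Matrix dimensions do not match; cannot compute");
            }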
            double[][] result = new double[matrixRowNums][matrixColNums];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < matrixColNums; j++) {
                    result[i][j] = matrix[i][j] + a.getMatrix()[i][j];
                }
            }
            return new Matrix(result);
        }
    
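        /**
         * Matrix subtraction.
         *
         * @param a matrix with the same dimensions as this one
         * @return the element-wise difference this - a
         */
        public Matrix subtract(Matrix a) throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            if (a.getMatrix() == null) {
                throw new Exception("Argument matrix is empty");
            }
            // reject a mismatch in either dimension
            if (matrixRowNums != a.getMatrixRowNums() || matrixColNums != a.getMatrixColNums()) {
                throw new Exception("Matrix dimensions do not match; cannot compute");
            }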
            double[][] result = new double[matrixRowNums][matrixColNums];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < matrixColNums; j++) {
                    result[i][j] = matrix[i][j] - a.getMatrix()[i][j];
                }
            }
            return new Matrix(result);
        }
    
        /**
         * Sums each row.
         *
         * @return a rowNums x 1 matrix of row sums
         */
        public Matrix sumRow() throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            double[][] result = new double[matrixRowNums][1];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < matrixColNums; j++) {
                    // the result has a single column, so accumulate into column 0
                    result[i][0] += matrix[i][j];
                }
            }
            return new Matrix(result);
        }
    
        /**
         * Sums each column.
         *
         * @return a 1 x colNums matrix of column sums
         */
        public Matrix sumCol() throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            double[][] result = new double[1][matrixColNums];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < matrixColNums; j++) {
                    // accumulate per column, so the result is indexed by j
                    result[0][j] += matrix[i][j];
                }
            }
            return new Matrix(result);
        }
    
        /**
         * Sums all elements.
         *
         * @return the sum of every element
         */
        public double sumAll() throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            double result = 0;
            for (double[] doubles : matrix) {
                for (int j = 0; j < matrixColNums; j++) {
                    result += doubles[j];
                }
            }
            return result;
        }
    
        /**
         * Squares every element.
         *
         * @return a matrix of the squared elements
         */
        public Matrix square() throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            double[][] result = new double[matrixRowNums][matrixColNums];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < matrixColNums; j++) {
                    result[i][j] = matrix[i][j] * matrix[i][j];
                }
            }
            return new Matrix(result);
        }
    
        /**
         * Matrix transpose.
         *
         * @return the transposed matrix
         */
        public Matrix transpose() throws Exception {
            if (matrix == null) {
                throw new Exception("Matrix is empty");
            }
            double[][] result = new double[matrixColNums][matrixRowNums];
            for (int i = 0; i < matrixRowNums; i++) {
                for (int j = 0; j < matrixColNums; j++) {
                    result[j][i] = matrix[i][j];
                }
            }
            return new Matrix(result);
        }
    
        @Override
        public String toString() {
            StringBuilder stringBuilder = new StringBuilder();
            stringBuilder.append("\n");
            for (int i = 0; i < matrixRowNums; i++) {
                stringBuilder.append("# ");
                for (int j = 0; j < matrixColNums; j++) {
                    stringBuilder.append(matrix[i][j]).append("\t ");
                }
                stringBuilder.append("#\n");
            }
            stringBuilder.append("\n");
            return stringBuilder.toString();
        }
    }
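
    The difference between multiple (matrix product) and pointMultiple (element-wise, or Hadamard, product) matters in the training code, so here is a small sketch of both on plain double[][] arrays. The class and method names (ProductDemo, matMul, hadamard) are made up for this illustration and are not part of the project.

```java
import java.util.Arrays;

public class ProductDemo {
    /** Matrix product: (n x k) times (k x m) gives (n x m). */
    static double[][] matMul(double[][] a, double[][] b) {
        if (a[0].length != b.length) {
            throw new IllegalArgumentException("columns of a must equal rows of b");
        }
        double[][] r = new double[a.length][b[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b[0].length; j++) {
                for (int k = 0; k < b.length; k++) {
                    r[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return r;
    }

    /** Element-wise (Hadamard) product: both operands must have the same shape. */
    static double[][] hadamard(double[][] a, double[][] b) {
        // a mismatch in EITHER dimension is an error, hence "||"
        if (a.length != b.length || a[0].length != b[0].length) {
            throw new IllegalArgumentException("shapes must match");
        }
        double[][] r = new double[a.length][a[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[0].length; j++) {
                r[i][j] = a[i][j] * b[i][j];
            }
        }
        return r;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(matMul(a, b)));   // [[19.0, 22.0], [43.0, 50.0]]
        System.out.println(Arrays.deepToString(hadamard(a, b))); // [[5.0, 12.0], [21.0, 32.0]]
    }
}
```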

    ActivationFunction interface

    public interface ActivationFunction {
        // computes the function value
        double computeValue(double val);
        // computes the derivative
        double computeDerivative(double val);
    }

    Sigmoid

    import java.io.Serializable;
    
    public class Sigmoid implements ActivationFunction, Serializable {
        @Override
        public double computeValue(double val) {
            return 1 / (1 + Math.exp(-val));
        }
    
        @Override
        public double computeDerivative(double val) {
            return computeValue(val) * (1 - computeValue(val));
        }
    }
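
    The derivative used above relies on the identity s'(x) = s(x)(1 - s(x)) for the sigmoid s. A quick way to convince yourself is to compare it with a central finite difference; the sketch below does that. The class name SigmoidCheck and the step size 1e-5 are choices made for this example only.

```java
public class SigmoidCheck {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Analytic derivative: s(x) * (1 - s(x)), evaluating the sigmoid once. */
    static double derivative(double x) {
        double s = sigmoid(x);
        return s * (1.0 - s);
    }

    /** Central finite-difference approximation of the derivative. */
    static double numericDerivative(double x, double h) {
        return (sigmoid(x + h) - sigmoid(x - h)) / (2.0 * h);
    }

    public static void main(String[] args) {
        for (double x : new double[] {-2.0, 0.0, 1.5}) {
            System.out.println(x + ": analytic=" + derivative(x)
                    + ", numeric=" + numericDerivative(x, 1e-5));
        }
    }
}
```

    At x = 0 the sigmoid is 0.5, so the derivative is 0.25, its maximum; far from zero it approaches 0, which is why saturated neurons learn slowly.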

    BPParameter

    Holds the parameters needed to train the BP neural network.

    import java.io.Serializable;
    
    public class BPParameter implements Serializable {
    
        // number of input-layer neurons
        private int inputLayerNeuronNum = 3;
        // number of hidden-layer neurons
        private int hiddenLayerNeuronNum = 3;
        // number of output-layer neurons
        private int outputLayerNeuronNum = 1;
        // normalization interval
        private double normalizationMin = 0.2;
        private double normalizationMax = 0.8;
        // learning step size
        private double step = 0.05;
        // momentum factor
        private double momentumFactor = 0.2;
        // activation function
        private ActivationFunction activationFunction = new Sigmoid();
        // required precision (error threshold)
        private double precision = 0.000001;
        // maximum number of iterations
        private int maxTimes = 1000000;
    
        public double getMomentumFactor() {
            return momentumFactor;
        }
    
        public void setMomentumFactor(double momentumFactor) {
            this.momentumFactor = momentumFactor;
        }
    
        public double getStep() {
            return step;
        }
    
        public void setStep(double step) {
            this.step = step;
        }
    
        public double getNormalizationMin() {
            return normalizationMin;
        }
    
        public void setNormalizationMin(double normalizationMin) {
            this.normalizationMin = normalizationMin;
        }
    
        public double getNormalizationMax() {
            return normalizationMax;
        }
    
        public void setNormalizationMax(double normalizationMax) {
            this.normalizationMax = normalizationMax;
        }
    
        public int getInputLayerNeuronNum() {
            return inputLayerNeuronNum;
        }
    
        public void setInputLayerNeuronNum(int inputLayerNeuronNum) {
            this.inputLayerNeuronNum = inputLayerNeuronNum;
        }
    
        public int getHiddenLayerNeuronNum() {
            return hiddenLayerNeuronNum;
        }
    
        public void setHiddenLayerNeuronNum(int hiddenLayerNeuronNum) {
            this.hiddenLayerNeuronNum = hiddenLayerNeuronNum;
        }
    
        public int getOutputLayerNeuronNum() {
            return outputLayerNeuronNum;
        }
    
        public void setOutputLayerNeuronNum(int outputLayerNeuronNum) {
            this.outputLayerNeuronNum = outputLayerNeuronNum;
        }
    
        public ActivationFunction getActivationFunction() {
            return activationFunction;
        }
    
        public void setActivationFunction(ActivationFunction activationFunction) {
            this.activationFunction = activationFunction;
        }
    
        public double getPrecision() {
            return precision;
        }
    
        public void setPrecision(double precision) {
            this.precision = precision;
        }
    
        public int getMaxTimes() {
            return maxTimes;
        }
    
        public void setMaxTimes(int maxTimes) {
            this.maxTimes = maxTimes;
        }
    }
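
    The normalization interval deserves a short illustration. With the defaults above, each column of the data is mapped linearly from its own [min, max] to [0.2, 0.8]; presumably the interval stays away from 0 and 1 because a sigmoid output only approaches those values asymptotically. The sketch below shows the mapping and its inverse on scalars. The class name MinMaxDemo and the sample values are hypothetical.

```java
public class MinMaxDemo {
    /** Maps x from [dataMin, dataMax] to [lo, hi]. Assumes dataMax != dataMin. */
    static double normalize(double x, double dataMin, double dataMax, double lo, double hi) {
        return lo + (x - dataMin) / (dataMax - dataMin) * (hi - lo);
    }

    /** Inverse mapping: from [lo, hi] back to [dataMin, dataMax]. Assumes hi != lo. */
    static double inverse(double v, double dataMin, double dataMax, double lo, double hi) {
        return dataMin + (v - lo) / (hi - lo) * (dataMax - dataMin);
    }

    public static void main(String[] args) {
        // Hypothetical process durations; 10 is the column minimum and 40 the maximum.
        double[] durations = {10.0, 25.0, 40.0};
        for (double d : durations) {
            double v = normalize(d, 10.0, 40.0, 0.2, 0.8);
            System.out.println(d + " -> " + v + " -> " + inverse(v, 10.0, 40.0, 0.2, 0.8));
        }
    }
}
```

    A column whose minimum equals its maximum cannot be normalized this way, because the denominator is zero; such columns need a separate rule.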

    BPModel

    The BP neural network model: holds the weights and thresholds, the training parameters, and related attributes.

    import java.io.Serializable;
    
    public class BPModel implements Serializable {
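        // weights and thresholds of the BP neural network
        private Matrix weightIJ;
        private Matrix b1;
        private Matrix weightJP;
        private Matrix b2;
        /* per-column extremes of the training data, used for normalization and de-normalization */
        private Matrix inputMax;
        private Matrix inputMin;
        private Matrix outputMax;
        private Matrix outputMin;
        /* training parameters of the BP neural network */
        private BPParameter bpParameter;
        /* training outcome: final error and number of iterations */
        private double error;
        private int times;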
    
        public Matrix getWeightIJ() {
            return weightIJ;
        }
    
        public void setWeightIJ(Matrix weightIJ) {
            this.weightIJ = weightIJ;
        }
    
        public Matrix getB1() {
            return b1;
        }
    
        public void setB1(Matrix b1) {
            this.b1 = b1;
        }
    
        public Matrix getWeightJP() {
            return weightJP;
        }
    
        public void setWeightJP(Matrix weightJP) {
            this.weightJP = weightJP;
        }
    
        public Matrix getB2() {
            return b2;
        }
    
        public void setB2(Matrix b2) {
            this.b2 = b2;
        }
    
        public Matrix getInputMax() {
            return inputMax;
        }
    
        public void setInputMax(Matrix inputMax) {
            this.inputMax = inputMax;
        }
    
        public Matrix getInputMin() {
            return inputMin;
        }
    
        public void setInputMin(Matrix inputMin) {
            this.inputMin = inputMin;
        }
    
        public Matrix getOutputMax() {
            return outputMax;
        }
    
        public void setOutputMax(Matrix outputMax) {
            this.outputMax = outputMax;
        }
    
        public Matrix getOutputMin() {
            return outputMin;
        }
    
        public void setOutputMin(Matrix outputMin) {
            this.outputMin = outputMin;
        }
    
        public BPParameter getBpParameter() {
            return bpParameter;
        }
    
        public void setBpParameter(BPParameter bpParameter) {
            this.bpParameter = bpParameter;
        }
    
        public double getError() {
            return error;
        }
    
        public void setError(double error) {
            this.error = error;
        }
    
        public int getTimes() {
            return times;
        }
    
        public void setTimes(int times) {
            this.times = times;
        }
    }

    BPNeuralNetworkFactory

    The BP neural network factory; provides training and related functions for BP neural networks.
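
    For orientation, the standard batch gradient-descent updates for a network of this shape can be written in matrix form. The notation is introduced here for illustration and does not come from the article: X and T are the normalized inputs and targets with one sample per row, \mathbf{1} is a column of ones with one entry per sample, b_1 and b_2 are row vectors, f is the activation function, \eta is the step size, and \alpha is the momentum factor.

```latex
\begin{aligned}
H &= f(X W_{IJ} + \mathbf{1} b_1), &
Y &= f(H W_{JP} + \mathbf{1} b_2), &
E &= \tfrac{1}{2} \sum_{n,p} (T_{np} - Y_{np})^2 \\
\delta_P &= (T - Y) \odot f'(H W_{JP} + \mathbf{1} b_2), &
\delta_J &= (\delta_P W_{JP}^{\top}) \odot f'(X W_{IJ} + \mathbf{1} b_1) \\
\Delta W_{JP} &= \eta\, H^{\top} \delta_P, &
\Delta b_2 &= \eta\, \mathbf{1}^{\top} \delta_P \\
\Delta W_{IJ} &= \eta\, X^{\top} \delta_J, &
\Delta b_1 &= \eta\, \mathbf{1}^{\top} \delta_J \\
W^{(t+1)} &= W^{(t)} + \Delta W^{(t)} + \alpha\, \Delta W^{(t-1)}
\end{aligned}
```

    The last line applies to every weight matrix and threshold vector alike. Because the thresholds are added to the net input, their corrections carry the same sign as the weight corrections.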

    import java.io.*;
    import java.util.*;
    
    public class BPNeuralNetworkFactory {
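        /**
         * Trains a BP neural network model.
         * @param bpParameter training parameters
         * @param inputAndOutput training samples, one per row: input columns first, then output columns
         * @return the trained model
         */
        public BPModel trainBP(BPParameter bpParameter, Matrix inputAndOutput) throws Exception {
            // the model returned by training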
            BPModel result = new BPModel();
            result.setBpParameter(bpParameter);
    
            ActivationFunction activationFunction = bpParameter.getActivationFunction();
            int inputNum = bpParameter.getInputLayerNeuronNum();
            int hiddenNum = bpParameter.getHiddenLayerNeuronNum();
            int outputNum = bpParameter.getOutputLayerNeuronNum();
            double normalizationMin = bpParameter.getNormalizationMin();
            double normalizationMax = bpParameter.getNormalizationMax();
            double step = bpParameter.getStep();
            double momentumFactor = bpParameter.getMomentumFactor();
            double precision = bpParameter.getPrecision();
            int maxTimes = bpParameter.getMaxTimes();
    
            if(inputAndOutput.getMatrixColNums() != inputNum + outputNum){
                throw new Exception("Column count does not match the number of input and output neurons; please fix");
            }
            // initialize the weights
            Matrix weightIJ = initWeight(inputNum, hiddenNum);
            Matrix weightJP = initWeight(hiddenNum, outputNum);
    
            // initialize the thresholds
            Matrix b1 = initThreshold(hiddenNum);
            Matrix b2 = initThreshold(outputNum);
    
            // previous corrections, used for the momentum term
            Matrix deltaWeightIJ0 = new Matrix(inputNum, hiddenNum);
            Matrix deltaWeightJP0 = new Matrix(hiddenNum, outputNum);
            Matrix deltaB10 = new Matrix(1, hiddenNum);
            Matrix deltaB20 = new Matrix(1, outputNum);
    
            Matrix input = new Matrix(new double[inputAndOutput.getMatrixRowNums()][inputNum]);
            Matrix output = new Matrix(new double[inputAndOutput.getMatrixRowNums()][outputNum]);
            for (int i = 0; i < inputAndOutput.getMatrixRowNums(); i++) {
                for (int j = 0; j < inputNum; j++) {
                    input.getMatrix()[i][j] = inputAndOutput.getValOfIdx(i,j);
                }
                for (int j = 0; j < inputAndOutput.getMatrixColNums() - inputNum; j++) {
                    output.getMatrix()[i][j] = inputAndOutput.getValOfIdx(i,inputNum+j);
                }
            }
    
            // normalize inputs and outputs
            Map<String,Object> inputAfterNormalize = normalize(input, normalizationMin, normalizationMax);
            input = (Matrix) inputAfterNormalize.get("res");
            Matrix inputMax = (Matrix) inputAfterNormalize.get("max");
            Matrix inputMin = (Matrix) inputAfterNormalize.get("min");
            result.setInputMax(inputMax);
            result.setInputMin(inputMin);
    
            Map<String,Object> outputAfterNormalize = normalize(output, normalizationMin, normalizationMax);
            output = (Matrix) outputAfterNormalize.get("res");
            Matrix outputMax = (Matrix) outputAfterNormalize.get("max");
            Matrix outputMin = (Matrix) outputAfterNormalize.get("min");
            result.setOutputMax(outputMax);
            result.setOutputMin(outputMin);
    
            int times = 1;
            double E = 0; // total error
            while (times < maxTimes) {
                /*----------------- forward propagation ---------------------*/
                // hidden-layer input
                Matrix jIn = input.multiple(weightIJ);
                double[][] b1CopyArr = new double[jIn.getMatrixRowNums()][];
                // expand the threshold row to one row per sample
                for (int i = 0; i < jIn.getMatrixRowNums(); i++) {
                    b1CopyArr[i] = b1.getMatrix()[0];
                }
                Matrix b1Copy = new Matrix(b1CopyArr);
                // add the threshold
                jIn = jIn.plus(b1Copy);
                // hidden-layer output
                Matrix jOut = computeValue(jIn,activationFunction);
                // output-layer input
                Matrix pIn = jOut.multiple(weightJP);
                double[][] b2CopyArr = new double[pIn.getMatrixRowNums()][];
                // expand the threshold row to one row per sample
                for (int i = 0; i < pIn.getMatrixRowNums(); i++) {
                    b2CopyArr[i] = b2.getMatrix()[0];
                }
                Matrix b2Copy = new Matrix(b2CopyArr);
                // add the threshold
                pIn = pIn.plus(b2Copy);
                // output-layer output
                Matrix pOut = computeValue(pIn,activationFunction);
                // compute the error
                Matrix e = output.subtract(pOut);
                E = computeE(e); // total error
                // stop once the required precision is reached
                if (Math.abs(E) <= precision) {
                    System.out.println("Required precision reached");
                    break;
                }
    
                /*----------------- backward propagation ---------------------*/
                // output-layer delta: (expected - actual) .* f'(pIn), one row per sample
                Matrix deltaO = e.pointMultiple(computeDerivative(pIn,activationFunction));
                // hidden-layer delta: (deltaO * weightJP^T) .* f'(jIn)
                Matrix deltaJ = deltaO.multiple(weightJP.transpose())
                        .pointMultiple(computeDerivative(jIn, activationFunction));
                // row of ones, used to sum the deltas over all samples
                Matrix ones = new Matrix(1.0, 1, input.getMatrixRowNums());
    
                // weight correction between layers J and P
                Matrix deltaWeightJP = jOut.transpose().multiple(deltaO).multiple(step);
                // threshold correction for layer P (1 x outputNum)
                Matrix deltaThresholdP = ones.multiple(deltaO).multiple(step);
                // weight correction between layers I and J
                Matrix deltaWeightIJ = input.transpose().multiple(deltaJ).multiple(step);
                // threshold correction for layer J (1 x hiddenNum); the threshold is added to
                // the net input, so its correction has the same sign as the weight corrections
                Matrix deltaThresholdJ = ones.multiple(deltaJ).multiple(step);
    
                if (times == 1) {
                    // first iteration: update weights and thresholds without momentum
                    weightIJ = weightIJ.plus(deltaWeightIJ);
                    weightJP = weightJP.plus(deltaWeightJP);
                    b1 = b1.plus(deltaThresholdJ);
                    b2 = b2.plus(deltaThresholdP);
                }else{
                    // add the momentum term
                    weightIJ = weightIJ.plus(deltaWeightIJ).plus(deltaWeightIJ0.multiple(momentumFactor));
                    weightJP = weightJP.plus(deltaWeightJP).plus(deltaWeightJP0.multiple(momentumFactor));
                    b1 = b1.plus(deltaThresholdJ).plus(deltaB10.multiple(momentumFactor));
                    b2 = b2.plus(deltaThresholdP).plus(deltaB20.multiple(momentumFactor));
                }
    
                // remember this iteration's corrections for the next momentum term
                deltaWeightIJ0 = deltaWeightIJ;
                deltaWeightJP0 = deltaWeightJP;
                deltaB10 = deltaThresholdJ;
                deltaB20 = deltaThresholdP;
    
                times++;
            }
    
            result.setWeightIJ(weightIJ);
            result.setWeightJP(weightJP);
            result.setB1(b1);
            result.setB2(b2);
            result.setError(E);
            result.setTimes(times);
            System.out.println("Iterations: " + times + ", error: " + E);
    
            return result;
        }
    
        /**
         * Computes the output of a trained BP neural network.
         * @param bpModel trained model
         * @param input raw (un-normalized) inputs, one sample per row
         * @return de-normalized outputs, one row per sample
         */
        public Matrix computeBP(BPModel bpModel,Matrix input) throws Exception {
            if (input.getMatrixColNums() != bpModel.getBpParameter().getInputLayerNeuronNum()) {
                throw new Exception("Input matrix has the wrong dimensions");
            }
            ActivationFunction activationFunction = bpModel.getBpParameter().getActivationFunction();
            Matrix weightIJ = bpModel.getWeightIJ();
            Matrix weightJP = bpModel.getWeightJP();
            Matrix b1 = bpModel.getB1();
            Matrix b2 = bpModel.getB2();
            double normalizationMin = bpModel.getBpParameter().getNormalizationMin();
            double normalizationMax = bpModel.getBpParameter().getNormalizationMax();
            double[][] normalizedInput = new double[input.getMatrixRowNums()][input.getMatrixColNums()];
            for (int i = 0; i < input.getMatrixRowNums(); i++) {
                for (int j = 0; j < input.getMatrixColNums(); j++) {
                    double min = bpModel.getInputMin().getValOfIdx(0,j);
                    double max = bpModel.getInputMax().getValOfIdx(0,j);
                    // a column that was constant in the training data (max == min)
                    // is left at 0 to avoid dividing by zero
                    if (max != min) {
                        normalizedInput[i][j] = normalizationMin
                                + (input.getValOfIdx(i,j) - min) / (max - min)
                                * (normalizationMax - normalizationMin);
                    }
                }
            }
            Matrix normalizedInputMatrix = new Matrix(normalizedInput);
            Matrix jIn = normalizedInputMatrix.multiple(weightIJ);
            double[][] b1CopyArr = new double[jIn.getMatrixRowNums()][];
            // expand the threshold row to one row per sample
            for (int i = 0; i < jIn.getMatrixRowNums(); i++) {
                b1CopyArr[i] = b1.getMatrix()[0];
            }
            Matrix b1Copy = new Matrix(b1CopyArr);
            // add the threshold
            jIn = jIn.plus(b1Copy);
            // hidden-layer output
            Matrix jOut = computeValue(jIn,activationFunction);
            // output-layer input
            Matrix pIn = jOut.multiple(weightJP);
            double[][] b2CopyArr = new double[pIn.getMatrixRowNums()][];
            // expand the threshold row to one row per sample
            for (int i = 0; i < pIn.getMatrixRowNums(); i++) {
                b2CopyArr[i] = b2.getMatrix()[0];
            }
            Matrix b2Copy = new Matrix(b2CopyArr);
            // add the threshold
            pIn = pIn.plus(b2Copy);
            // output-layer output
            Matrix pOut = computeValue(pIn,activationFunction);
            // de-normalize
            return inverseNormalize(pOut, normalizationMax, normalizationMin, bpModel.getOutputMax(), bpModel.getOutputMin());
        }
    
        // initializes an x-by-y weight matrix with random values in [-1, 1)
        private Matrix initWeight(int x,int y){
            Random random=new Random();
            double[][] weight = new double[x][y];
            for (int i = 0; i < x; i++) {
                for (int j = 0; j < y; j++) {
                    weight[i][j] = 2*random.nextDouble()-1;
                }
            }
            return new Matrix(weight);
        }
        // initializes a 1-by-x threshold row with random values in [-1, 1)
        private Matrix initThreshold(int x){
            Random random = new Random();
            double[][] result = new double[1][x];
            for (int i = 0; i < x; i++) {
                result[0][i] = 2*random.nextDouble()-1;
            }
            return new Matrix(result);
        }
    
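        /**
         * Applies the activation function to every element.
         * @param a matrix of net inputs
         * @param activationFunction activation function to apply
         * @return matrix of activation values
         */
        private Matrix computeValue(Matrix a, ActivationFunction activationFunction) throws Exception {
            if (a.getMatrix() == null) {
                throw new Exception("Argument matrix is empty");
            }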
            double[][] result = new double[a.getMatrixRowNums()][a.getMatrixColNums()];
            for (int i = 0; i < a.getMatrixRowNums(); i++) {
                for (int j = 0; j < a.getMatrixColNums(); j++) {
                    result[i][j] = activationFunction.computeValue(a.getValOfIdx(i,j));
                }
            }
            return new Matrix(result);
        }
    
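        /**
         * Applies the derivative of the activation function to every element.
         * @param a matrix of net inputs
         * @param activationFunction activation function whose derivative is applied
         * @return matrix of derivative values
         */
        private Matrix computeDerivative(Matrix a , ActivationFunction activationFunction) throws Exception {
            if (a.getMatrix() == null) {
                throw new Exception("Argument matrix is empty");
            }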
            double[][] result = new double[a.getMatrixRowNums()][a.getMatrixColNums()];
            for (int i = 0; i < a.getMatrixRowNums(); i++) {
                for (int j = 0; j < a.getMatrixColNums(); j++) {
                    result[i][j] = activationFunction.computeDerivative(a.getValOfIdx(i,j));
                }
            }
            return new Matrix(result);
        }
    
        /**
         * Min-max normalization, column by column.
         * @param a data to normalize
         * @param normalizationMin  lower bound of the target interval
         * @param normalizationMax  upper bound of the target interval
         * @return a map with keys "res" (normalized data), "max" and "min" (per-column extremes)
         */
        private Map<String, Object> normalize(Matrix a, double normalizationMin, double normalizationMax) throws Exception {
            HashMap<String, Object> result = new HashMap<>();
            double[][] maxArr = new double[1][a.getMatrixColNums()];
            double[][] minArr = new double[1][a.getMatrixColNums()];
            double[][] res = new double[a.getMatrixRowNums()][a.getMatrixColNums()];
            for (int i = 0; i < a.getMatrixColNums(); i++) {
                List<Double> tmp = new ArrayList<>();
                for (int j = 0; j < a.getMatrixRowNums(); j++) {
                    tmp.add(a.getValOfIdx(j,i));
                }
                double max = Collections.max(tmp);
                double min = Collections.min(tmp);
                // normalize; a constant column (max == min) is left at 0 to avoid dividing by zero
                if (max != min) {
                    for (int j = 0; j < a.getMatrixRowNums(); j++) {
                        res[j][i] = normalizationMin + (a.getValOfIdx(j,i) - min) / (max - min) * (normalizationMax - normalizationMin);
                    }
                }
                maxArr[0][i] = max;
                minArr[0][i] = min;
            }
            result.put("max", new Matrix(maxArr));
            result.put("min", new Matrix(minArr));
            result.put("res", new Matrix(res));
            return result;
        }
    
        /**
         * Inverse of the min-max normalization.
         * @param a data to de-normalize
         * @param normalizationMax upper bound of the normalization interval
         * @param normalizationMin lower bound of the normalization interval
         * @param dataMax   per-column maximum of the original data
         * @param dataMin   per-column minimum of the original data
         * @return the de-normalized data
         */
        private Matrix inverseNormalize(Matrix a, double normalizationMax, double normalizationMin , Matrix dataMax,Matrix dataMin) throws Exception {
            double[][] res = new double[a.getMatrixRowNums()][a.getMatrixColNums()];
            for (int i = 0; i < a.getMatrixColNums(); i++) {
                // de-normalize the column
                if (dataMin.getValOfIdx(0,i) != 0 || dataMax.getValOfIdx(0,i) != 0) {
                    for (int j = 0; j < a.getMatrixRowNums(); j++) {
                        res[j][i] = dataMin.getValOfIdx(0,i) + (dataMax.getValOfIdx(0,i) - dataMin.getValOfIdx(0,i)) * (a.getValOfIdx(j,i) - normalizationMin) / (normalizationMax - normalizationMin);
                    }
                }
            }
            return new Matrix(res);
        }
    
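        /**
         * Computes the total error E = 0.5 * (sum of squared differences).
         * @param e differences between expected and actual outputs
         * @return the total error
         */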
        private double computeE(Matrix e) throws Exception {
            e = e.square();
            return 0.5*e.sumAll();
        }
    
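        /**
         * Serializes a BP model to a local file.
         * @param bpModel model to save
         * @param path destination file path
         * @throws IOException if the file cannot be written
         */
        public void serialize(BPModel bpModel,String path) throws IOException {
            File file = new File(path);
            System.out.println(file.getAbsolutePath());
            // try-with-resources closes the stream even if writing fails
            try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file))) {
                out.writeObject(bpModel);
            }
        }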
    
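        /**
         * Deserializes a BP model from a local file.
         * @param path file to read
         * @return the restored model
         * @throws IOException if the file cannot be read
         * @throws ClassNotFoundException if the serialized class is not on the classpath
         */
        public BPModel deSerialization(String path) throws IOException, ClassNotFoundException {
            File file = new File(path);
            // try-with-resources closes the stream even if reading fails
            try (ObjectInputStream oin = new ObjectInputStream(new FileInputStream(file))) {
                return (BPModel) oin.readObject(); // cast to BPModel
            }
        }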
    }

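    Usage

    The idea is to create a BPNeuralNetworkFactory object and call its trainBP(BPParameter bpParameter, Matrix inputAndOutput) method with a BPParameter object and the training data; this returns a BPModel object. The model can be serialized to local disk with BPNeuralNetworkFactory's serialization method, or kept in a cache. To use it, deserialize the BPModel object from disk and call BPNeuralNetworkFactory's computeBP(BPModel bpModel, Matrix input) method to obtain the computed value.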
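
    The save-then-restore part of this workflow uses standard Java object serialization. The sketch below shows the round trip in isolation; TinyModel is a hypothetical stand-in for a trained model so that the example stays self-contained, and the temporary file name is arbitrary.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

public class RoundTripDemo {
    /** Hypothetical stand-in for a trained model. */
    static class TinyModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final double[] weights;
        TinyModel(double[] weights) { this.weights = weights; }
    }

    static void save(TinyModel model, File file) throws IOException {
        // try-with-resources closes the stream even if writing fails
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file))) {
            out.writeObject(model);
        }
    }

    static TinyModel load(File file) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(file))) {
            return (TinyModel) in.readObject();
        }
    }

    /** Saves the weights to a temporary file, reads them back, and returns the restored copy. */
    static double[] roundTrip(double[] weights) {
        try {
            File file = File.createTempFile("tiny-model", ".ser");
            file.deleteOnExit();
            save(new TinyModel(weights), file);
            return load(file).weights;
        } catch (IOException | ClassNotFoundException ex) {
            throw new IllegalStateException("round trip failed", ex);
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(roundTrip(new double[] {0.1, -0.2})));
    }
}
```

    Every class reachable from the saved object must implement Serializable, which is why the entity classes and the Sigmoid class in this project all do.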

    For detailed usage, see: https://github.com/ineedahouse/top-algorithm-set-doc/blob/master/doc/bpnn/BPNeuralNetwork.md

    Source code on GitHub

    https://github.com/ineedahouse/top-algorithm-set

    If this helped you, please give the repository a star. Thanks!

    Reference: Zhao Yixiang (赵逸翔). 基于BP神经网络的无约束优化方法研究及应用 [D]. Northeast Agricultural University, 2019.

  • Original article: https://www.cnblogs.com/MrZhaoyx/p/13271832.html