zoukankan      html  css  js  c++  java
  • BP神经网络的Java实现(转)

    http://fantasticinblur.iteye.com/blog/1465497

    课程作业要求实现一个BPNN。这次尝试使用Java实现了一个。现共享之。版权属于大家。关于BPNN的原理,就不赘述了。

    下面是BPNN的实现代码。类名为BP。

    Java代码  收藏代码
    1. package ml;  
    2.   
    3. import java.util.Random;  
    4.   
    5. /** 
    6.  * BPNN. 
    7.  *  
    8.  * @author RenaQiu 
    9.  *  
    10.  */  
    11. public class BP {  
    12.     /** 
    13.      * input vector. 
    14.      */  
    15.     private final double[] input;  
    16.     /** 
    17.      * hidden layer. 
    18.      */  
    19.     private final double[] hidden;  
    20.     /** 
    21.      * output layer. 
    22.      */  
    23.     private final double[] output;  
    24.     /** 
    25.      * target. 
    26.      */  
    27.     private final double[] target;  
    28.   
    29.     /** 
    30.      * delta vector of the hidden layer . 
    31.      */  
    32.     private final double[] hidDelta;  
    33.     /** 
    34.      * output layer of the output layer. 
    35.      */  
    36.     private final double[] optDelta;  
    37.   
    38.     /** 
    39.      * learning rate. 
    40.      */  
    41.     private final double eta;  
    42.     /** 
    43.      * momentum. 
    44.      */  
    45.     private final double momentum;  
    46.   
    47.     /** 
    48.      * weight matrix from input layer to hidden layer. 
    49.      */  
    50.     private final double[][] iptHidWeights;  
    51.     /** 
    52.      * weight matrix from hidden layer to output layer. 
    53.      */  
    54.     private final double[][] hidOptWeights;  
    55.   
    56.     /** 
    57.      * previous weight update. 
    58.      */  
    59.     private final double[][] iptHidPrevUptWeights;  
    60.     /** 
    61.      * previous weight update. 
    62.      */  
    63.     private final double[][] hidOptPrevUptWeights;  
    64.   
    65.     public double optErrSum = 0d;  
    66.   
    67.     public double hidErrSum = 0d;  
    68.   
    69.     private final Random random;  
    70.   
    71.     /** 
    72.      * Constructor. 
    73.      * <p> 
    74.      * <strong>Note:</strong> The capacity of each layer will be the parameter 
    75.      * plus 1. The additional unit is used for smoothness. 
    76.      * </p> 
    77.      *  
    78.      * @param inputSize 
    79.      * @param hiddenSize 
    80.      * @param outputSize 
    81.      * @param eta 
    82.      * @param momentum 
    83.      * @param epoch 
    84.      */  
    85.     public BP(int inputSize, int hiddenSize, int outputSize, double eta,  
    86.             double momentum) {  
    87.   
    88.         input = new double[inputSize + 1];  
    89.         hidden = new double[hiddenSize + 1];  
    90.         output = new double[outputSize + 1];  
    91.         target = new double[outputSize + 1];  
    92.   
    93.         hidDelta = new double[hiddenSize + 1];  
    94.         optDelta = new double[outputSize + 1];  
    95.   
    96.         iptHidWeights = new double[inputSize + 1][hiddenSize + 1];  
    97.         hidOptWeights = new double[hiddenSize + 1][outputSize + 1];  
    98.   
    99.         random = new Random(19881211);  
    100.         randomizeWeights(iptHidWeights);  
    101.         randomizeWeights(hidOptWeights);  
    102.   
    103.         iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];  
    104.         hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];  
    105.   
    106.         this.eta = eta;  
    107.         this.momentum = momentum;  
    108.     }  
    109.   
    110.     private void randomizeWeights(double[][] matrix) {  
    111.         for (int i = 0, len = matrix.length; i != len; i++)  
    112.             for (int j = 0, len2 = matrix[i].length; j != len2; j++) {  
    113.                 double real = random.nextDouble();  
    114.                 matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;  
    115.             }  
    116.     }  
    117.   
    118.     /** 
    119.      * Constructor with default eta = 0.25 and momentum = 0.3. 
    120.      *  
    121.      * @param inputSize 
    122.      * @param hiddenSize 
    123.      * @param outputSize 
    124.      * @param epoch 
    125.      */  
    126.     public BP(int inputSize, int hiddenSize, int outputSize) {  
    127.         this(inputSize, hiddenSize, outputSize, 0.25, 0.9);  
    128.     }  
    129.   
    130.     /** 
    131.      * Entry method. The train data should be a one-dim vector. 
    132.      *  
    133.      * @param trainData 
    134.      * @param target 
    135.      */  
    136.     public void train(double[] trainData, double[] target) {  
    137.         loadInput(trainData);  
    138.         loadTarget(target);  
    139.         forward();  
    140.         calculateDelta();  
    141.         adjustWeight();  
    142.     }  
    143.   
    144.     /** 
    145.      * Test the BPNN. 
    146.      *  
    147.      * @param inData 
    148.      * @return 
    149.      */  
    150.     public double[] test(double[] inData) {  
    151.         if (inData.length != input.length - 1) {  
    152.             throw new IllegalArgumentException("Size Do Not Match.");  
    153.         }  
    154.         System.arraycopy(inData, 0, input, 1, inData.length);  
    155.         forward();  
    156.         return getNetworkOutput();  
    157.     }  
    158.   
    159.     /** 
    160.      * Return the output layer. 
    161.      *  
    162.      * @return 
    163.      */  
    164.     private double[] getNetworkOutput() {  
    165.         int len = output.length;  
    166.         double[] temp = new double[len - 1];  
    167.         for (int i = 1; i != len; i++)  
    168.             temp[i - 1] = output[i];  
    169.         return temp;  
    170.     }  
    171.   
    172.     /** 
    173.      * Load the target data. 
    174.      *  
    175.      * @param arg 
    176.      */  
    177.     private void loadTarget(double[] arg) {  
    178.         if (arg.length != target.length - 1) {  
    179.             throw new IllegalArgumentException("Size Do Not Match.");  
    180.         }  
    181.         System.arraycopy(arg, 0, target, 1, arg.length);  
    182.     }  
    183.   
    184.     /** 
    185.      * Load the training data. 
    186.      *  
    187.      * @param inData 
    188.      */  
    189.     private void loadInput(double[] inData) {  
    190.         if (inData.length != input.length - 1) {  
    191.             throw new IllegalArgumentException("Size Do Not Match.");  
    192.         }  
    193.         System.arraycopy(inData, 0, input, 1, inData.length);  
    194.     }  
    195.   
    196.     /** 
    197.      * Forward. 
    198.      *  
    199.      * @param layer0 
    200.      * @param layer1 
    201.      * @param weight 
    202.      */  
    203.     private void forward(double[] layer0, double[] layer1, double[][] weight) {  
    204.         // threshold unit.  
    205.         layer0[0] = 1.0;  
    206.         for (int j = 1, len = layer1.length; j != len; ++j) {  
    207.             double sum = 0;  
    208.             for (int i = 0, len2 = layer0.length; i != len2; ++i)  
    209.                 sum += weight[i][j] * layer0[i];  
    210.             layer1[j] = sigmoid(sum);  
    211.         }  
    212.     }  
    213.   
    214.     /** 
    215.      * Forward. 
    216.      */  
    217.     private void forward() {  
    218.         forward(input, hidden, iptHidWeights);  
    219.         forward(hidden, output, hidOptWeights);  
    220.     }  
    221.   
    222.     /** 
    223.      * Calculate output error. 
    224.      */  
    225.     private void outputErr() {  
    226.         double errSum = 0;  
    227.         for (int idx = 1, len = optDelta.length; idx != len; ++idx) {  
    228.             double o = output[idx];  
    229.             optDelta[idx] = o * (1d - o) * (target[idx] - o);  
    230.             errSum += Math.abs(optDelta[idx]);  
    231.         }  
    232.         optErrSum = errSum;  
    233.     }  
    234.   
    235.     /** 
    236.      * Calculate hidden errors. 
    237.      */  
    238.     private void hiddenErr() {  
    239.         double errSum = 0;  
    240.         for (int j = 1, len = hidDelta.length; j != len; ++j) {  
    241.             double o = hidden[j];  
    242.             double sum = 0;  
    243.             for (int k = 1, len2 = optDelta.length; k != len2; ++k)  
    244.                 sum += hidOptWeights[j][k] * optDelta[k];  
    245.             hidDelta[j] = o * (1d - o) * sum;  
    246.             errSum += Math.abs(hidDelta[j]);  
    247.         }  
    248.         hidErrSum = errSum;  
    249.     }  
    250.   
    251.     /** 
    252.      * Calculate errors of all layers. 
    253.      */  
    254.     private void calculateDelta() {  
    255.         outputErr();  
    256.         hiddenErr();  
    257.     }  
    258.   
    259.     /** 
    260.      * Adjust the weight matrix. 
    261.      *  
    262.      * @param delta 
    263.      * @param layer 
    264.      * @param weight 
    265.      * @param prevWeight 
    266.      */  
    267.     private void adjustWeight(double[] delta, double[] layer,  
    268.             double[][] weight, double[][] prevWeight) {  
    269.   
    270.         layer[0] = 1;  
    271.         for (int i = 1, len = delta.length; i != len; ++i) {  
    272.             for (int j = 0, len2 = layer.length; j != len2; ++j) {  
    273.                 double newVal = momentum * prevWeight[j][i] + eta * delta[i]  
    274.                         * layer[j];  
    275.                 weight[j][i] += newVal;  
    276.                 prevWeight[j][i] = newVal;  
    277.             }  
    278.         }  
    279.     }  
    280.   
    281.     /** 
    282.      * Adjust all weight matrices. 
    283.      */  
    284.     private void adjustWeight() {  
    285.         adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);  
    286.         adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);  
    287.     }  
    288.   
    289.     /** 
    290.      * Sigmoid. 
    291.      *  
    292.      * @param val 
    293.      * @return 
    294.      */  
    295.     private double sigmoid(double val) {  
    296.         return 1d / (1d + Math.exp(-val));  
    297.     }  
    298. }  

     为了验证正确性,我写了一个测试用例,目的是对于任意的整数(int型),BPNN在经过训练之后,能够准确地判断出它是奇数还是偶数,正数还是负数。首先对于训练的样本(是随机生成的数字),将它转化为一个32位的向量,向量的每个分量就是其二进制形式对应的位上的0或1。将目标输出视作一个4维的向量,[1,0,0,0]代表正奇数,[0,1,0,0]代表正偶数,[0,0,1,0]代表负奇数,[0,0,0,1]代表负偶数。

    训练样本为1000个,学习200次。

    Java代码  收藏代码
    1. package ml;  
    2.   
    3. import java.io.IOException;  
    4. import java.util.ArrayList;  
    5. import java.util.List;  
    6. import java.util.Random;  
    7.   
    8. public class Test {  
    9.   
    10.     /** 
    11.      * @param args 
    12.      * @throws IOException 
    13.      */  
    14.     public static void main(String[] args) throws IOException {  
    15.         BP bp = new BP(32, 15, 4);  
    16.   
    17.         Random random = new Random();  
    18.         List<Integer> list = new ArrayList<Integer>();  
    19.         for (int i = 0; i != 1000; i++) {  
    20.             int value = random.nextInt();  
    21.             list.add(value);  
    22.         }  
    23.   
    24.         for (int i = 0; i != 200; i++) {  
    25.             for (int value : list) {  
    26.                 double[] real = new double[4];  
    27.                 if (value >= 0)  
    28.                     if ((value & 1) == 1)  
    29.                         real[0] = 1;  
    30.                     else  
    31.                         real[1] = 1;  
    32.                 else if ((value & 1) == 1)  
    33.                     real[2] = 1;  
    34.                 else  
    35.                     real[3] = 1;  
    36.                 double[] binary = new double[32];  
    37.                 int index = 31;  
    38.                 do {  
    39.                     binary[index--] = (value & 1);  
    40.                     value >>>= 1;  
    41.                 } while (value != 0);  
    42.   
    43.                 bp.train(binary, real);  
    44.             }  
    45.         }  
    46.   
    47.         System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");  
    48.   
    49.         while (true) {  
    50.             byte[] input = new byte[10];  
    51.             System.in.read(input);  
    52.             Integer value = Integer.parseInt(new String(input).trim());  
    53.             int rawVal = value;  
    54.             double[] binary = new double[32];  
    55.             int index = 31;  
    56.             do {  
    57.                 binary[index--] = (value & 1);  
    58.                 value >>>= 1;  
    59.             } while (value != 0);  
    60.   
    61.             double[] result = bp.test(binary);  
    62.   
    63.             double max = -Integer.MIN_VALUE;  
    64.             int idx = -1;  
    65.   
    66.             for (int i = 0; i != result.length; i++) {  
    67.                 if (result[i] > max) {  
    68.                     max = result[i];  
    69.                     idx = i;  
    70.                 }  
    71.             }  
    72.   
    73.             switch (idx) {  
    74.             case 0:  
    75.                 System.out.format("%d是一个正奇数 ", rawVal);  
    76.                 break;  
    77.             case 1:  
    78.                 System.out.format("%d是一个正偶数 ", rawVal);  
    79.                 break;  
    80.             case 2:  
    81.                 System.out.format("%d是一个负奇数 ", rawVal);  
    82.                 break;  
    83.             case 3:  
    84.                 System.out.format("%d是一个负偶数 ", rawVal);  
    85.                 break;  
    86.             }  
    87.         }  
    88.     }  
    89.   
    90. }  

     运行结果截图如下:



     这个测试的例子非常简单。大家可以根据自己的需要去使用BP这个类。

  • 相关阅读:
    from fake_useragent import UserAgent
    teamviewer 安装 仅学习
    利用pandas 中的read_html 获取页面表格
    第十二天 最恶心的考试题
    第十三天 生成器和生成器函数, 列表推导式
    第十一天 函数名的使用以及第一类对象, 闭包, 迭代器
    第十天 动态参数,名称空间,作用域,函数的嵌套,gloabal / nonlocal 关键字
    初始函数, 函数的定义,函数名,函数体以及函数的调用,函数的返回值,函数的参数
    第八天 文件的读,写,追加,读写,写读,seek()光标的移动,修改文件以及另一种打开文件的方式
    第七天 1.基础数据类型的补充 2.set集合 3.深浅拷贝
  • 原文地址:https://www.cnblogs.com/bnuvincent/p/6476040.html
Copyright © 2011-2022 走看看