Machine learning demo: classifying watermelons

    The neural-network chapter of Prof. Zhou Zhihua's book comes with a small demo, which I reimplemented here.

    It is the simplest possible neural network, with a single hidden layer.

    As in the earlier exercises, the task is to predict whether a watermelon is good or bad.

    The work breaks down into the following steps.

    1. Data preprocessing

    Each categorical attribute of a watermelon is encoded as a number between 0 and 1; I simply assigned the codes by hand. Sugar content (like density) is already a value in that range, so it is used as-is.

    Color:   green (青绿) = 1,  dark (乌黑) = 0.5,  light (浅白) = 0

    Root:    curled (蜷缩) = 1,  slightly curled (稍蜷) = 0.5,  stiff (硬挺) = 0

    Sound:   dull (浊响) = 1,  muffled (沉闷) = 0.5,  crisp (清脆) = 0

    Texture: clear (清晰) = 1,  slightly blurry (稍糊) = 0.5,  blurry (模糊) = 0

    Navel:   sunken (凹陷) = 1,  slightly sunken (稍凹) = 0.5,  flat (平坦) = 0

    Touch:   hard-smooth (硬滑) = 1,  soft-sticky (软黏) = 0
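
    The same mapping can also be written down as a small lookup helper. The sketch below is only an illustration of the encoding step (the Encoder class and its encode method are not part of the original demo); it returns the 0 / 0.5 / 1 codes used in the data arrays in the next section.

    package BP;

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical helper (not in the original demo): maps each categorical
    // attribute value to the 0 / 0.5 / 1 code listed above. Density and sugar
    // content are already in [0, 1] and are used directly, so they are not mapped.
    public class Encoder {
        private static final Map<String, Double> CODE = new HashMap<>();
        static {
            CODE.put("青绿", 1.0); CODE.put("乌黑", 0.5); CODE.put("浅白", 0.0); // color
            CODE.put("蜷缩", 1.0); CODE.put("稍蜷", 0.5); CODE.put("硬挺", 0.0); // root
            CODE.put("浊响", 1.0); CODE.put("沉闷", 0.5); CODE.put("清脆", 0.0); // knock sound
            CODE.put("清晰", 1.0); CODE.put("稍糊", 0.5); CODE.put("模糊", 0.0); // texture
            CODE.put("凹陷", 1.0); CODE.put("稍凹", 0.5); CODE.put("平坦", 0.0); // navel
            CODE.put("硬滑", 1.0); CODE.put("软黏", 0.0);                        // touch
        }

        public static double encode(String value) {
            Double code = CODE.get(value);
            if (code == null)
                throw new IllegalArgumentException("unknown attribute value: " + value);
            return code;
        }
    }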

    2. Training set and test set

    package BP;

    public class TrainData {
        double[][] traindata;
        double[][] traindataoutput;
        double[][] testdata;
        double[][] testdataoutput;

        public TrainData() {
            traindata = new double[][]{
                new double[]{1, 1, 1, 1, 1, 1, 0.697, 0.460},
                new double[]{0.5, 1, 0.5, 1, 1, 1, 0.774, 0.376},
                new double[]{0.5, 1, 1, 1, 1, 1, 0.634, 0.264},
                //new double[]{1, 1, 0.5, 1, 1, 1, 0.608, 0.318, 1},
                //new double[]{0, 1, 1, 1, 1, 1, 0.556, 0.215, 1},
                new double[]{1, 0.5, 1, 1, 0.5, 0, 0.403, 0.237},
                new double[]{0.5, 0.5, 1, 0.5, 0.5, 0, 0.481, 0.149},
                //new double[]{0.5, 0.5, 1, 1, 0.5, 1, 0.437, 0.211, 1},

                //new double[]{0.5, 0.5, 0.5, 0.5, 0.5, 1, 0.666, 0.091, 0},
                //new double[]{1, 0, 0, 1, 0, 0, 0.243, 0.267, 0},
                //new double[]{0, 0, 0, 0, 0, 1, 0.245, 0.057, 0},
                //new double[]{0, 1, 1, 0, 0, 0, 0.343, 0.099, 0},
                new double[]{1, 0.5, 1, 0.5, 1, 1, 0.639, 0.161},
                new double[]{0, 0.5, 0, 0.5, 1, 1, 0.657, 0.198},
                new double[]{0.5, 0.5, 1, 1, 0.5, 0, 0.360, 0.370},
                new double[]{0, 1, 1, 0, 0, 1, 0.593, 0.042},
                new double[]{1, 1, 0.5, 0.5, 0.5, 1, 0.719, 0.103}
            };
            traindataoutput = new double[][]{
                new double[]{1},
                new double[]{1},
                new double[]{1},
                new double[]{1},
                new double[]{1},
                new double[]{0},
                new double[]{0},
                new double[]{0},
                new double[]{0},
                new double[]{0},
            };
            testdata = new double[][]{
                new double[]{1, 1, 0.5, 1, 1, 1, 0.608, 0.318},
                new double[]{0, 1, 1, 1, 1, 1, 0.556, 0.215},
                new double[]{0.5, 0.5, 1, 1, 0.5, 1, 0.437, 0.211},

                new double[]{0.5, 0.5, 0.5, 0.5, 0.5, 1, 0.666, 0.091},
                new double[]{1, 0, 0, 1, 0, 0, 0.243, 0.267},
                new double[]{0, 0, 0, 0, 0, 1, 0.245, 0.057},
                new double[]{0, 1, 1, 0, 0, 0, 0.343, 0.099},
            };
            testdataoutput = new double[][]{
                new double[]{1},
                new double[]{1},
                new double[]{1},
                new double[]{0},
                new double[]{0},
                new double[]{0},
                new double[]{0},
            };
        }

        // Print the training matrix to verify the rows were entered correctly.
        public static void main(String[] args) {
            TrainData t = new TrainData();
            for (int i = 0; i < t.traindata.length; i++) {
                // each row holds 8 feature values
                for (int j = 0; j < t.traindata[i].length; j++)
                    System.out.print(t.traindata[i][j] + " ");
                System.out.println();
            }
        }
    }
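
    Since a few rows were commented out of traindata and moved over to testdata, the feature matrices and label matrices can easily get out of step. A tiny check like the hypothetical DataCheck class below (my own addition, same package, not in the original code) confirms the row counts still match before training.

    package BP;

    // Hypothetical sanity check: verify that each feature matrix has exactly
    // one label row per data row before the network is trained.
    public class DataCheck {
        public static void main(String[] args) {
            TrainData t = new TrainData();
            System.out.println("train rows: " + t.traindata.length
                    + ", train labels: " + t.traindataoutput.length);
            System.out.println("test rows:  " + t.testdata.length
                    + ", test labels:  " + t.testdataoutput.length);
        }
    }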

    3. The BP main program

    package BP;

    import java.util.Random;

    public class BP {
        int innum;
        int hiddennum;
        int outnum;

        // input, hidden and output layers
        public double[] input;
        public double[] hidden;
        // output holds the values computed by the network
        public double[] output;

        // realoutput holds the target values supplied during training
        public double[] realoutput;

        // v[i][j] is the weight from input unit i to hidden unit j,
        // w[i][j] is the weight from hidden unit i to output unit j
        public double[][] v;
        public double[][] w;

        // afa holds the hidden-layer thresholds, beta the output-layer thresholds
        public double[] beta;
        public double[] afa;

        // learning rate
        public double eta;
        // momentum term (declared but never used in this demo)
        public double momentum;
        public final Random random;

        public BP(int inputnum, int hiddennum, int outputnum, double learningrate) {
            innum = inputnum;
            this.hiddennum = hiddennum;
            outnum = outputnum;

            input = new double[inputnum + 1];
            hidden = new double[hiddennum + 1];
            output = new double[outputnum + 1];
            realoutput = new double[outputnum + 1];

            v = new double[inputnum + 1][hiddennum + 1];
            w = new double[hiddennum + 1][outputnum + 1];

            beta = new double[outputnum + 1];
            afa = new double[hiddennum + 1];
            for (int i = 0; i < outputnum; i++)
                beta[i] = 0.0;
            for (int i = 0; i < hiddennum; i++)
                afa[i] = 0.0;

            eta = learningrate;
            // the random seed has a large effect on the final result
            random = new Random(19950326);
            randomizeWeights(w);
            randomizeWeights(v);
        }

        public void testData(double[] in) {
            input = in;
            getNetOutput();
        }

        // Specific to this task: output > 0.5 means a good melon, otherwise a bad one.
        public int predict(double[] in) {
            testData(in);
            if (output[0] > 0.5)
                return 1;
            else
                return 0;
        }

        // Accuracy on the test set.
        public double getAccuracy(double[][] in, double[][] out) {
            int rightans = 0, wrongans = 0;
            for (int i = 0; i < in.length; i++) {
                if (predict(in[i]) == (out[i][0])) {
                    //System.out.println("predicted: " + predict(in[i]) + " actual: " + out[i][0]);
                    rightans++;
                } else {
                    //System.out.println("predicted: " + predict(in[i]) + " actual: " + out[i][0]);
                    wrongans++;
                }
            }
            System.out.println("right: " + rightans + " wrong: " + wrongans);
            return (double) rightans / (double) (rightans + wrongans);
        }

        // times is the number of training epochs.
        public void train(int times) {
            TrainData t = new TrainData();
            double wu = 0.0, acc = 0.0;
            int n = t.traindata.length;
            for (int i = 0; i < times; i++) {
                wu = 0.0;
                for (int j = 0; j < n; j++) {
                    traindata(t.traindata[j], t.traindataoutput[j]);
                    wu += getDeviation();
                }
                wu = wu / ((double) n);
                System.out.println("epoch " + i + " mean training error: " + wu);
                acc = getAccuracy(t.testdata, t.testdataoutput);
                System.out.println("test accuracy: " + acc);
            }
        }

        // Train on a single input sample.
        public void traindata(double[] in, double[] out) {
            input = in;
            realoutput = out;
            getNetOutput();
            adjustParameter();
        }

        // Squared error E on the current sample.
        public double getDeviation() {
            double e = 0.0;
            for (int i = 0; i < outnum; i++)
                e += (output[i] - realoutput[i]) * (output[i] - realoutput[i]);
            e *= 0.5;
            return e;
        }

        // Update weights and thresholds with the standard BP rules.
        public void adjustParameter() {
            double g[], e = 0.0;
            g = new double[outnum];
            int i, j;
            for (i = 0; i < outnum; i++) {
                g[i] = output[i] * (1 - output[i]) * (realoutput[i] - output[i]);
                beta[i] -= eta * g[i];
                for (j = 0; j < hiddennum; j++) {
                    w[j][i] += eta * g[i] * hidden[j];
                }
            }
            for (i = 0; i < hiddennum; i++) {
                e = 0.0;
                for (j = 0; j < outnum; j++)
                    e += g[j] * w[i][j];
                e = hidden[i] * (1 - hidden[i]) * e;
                afa[i] -= eta * e;
                for (j = 0; j < innum; j++)
                    v[j][i] += eta * e * input[j];
            }
        }

        // Forward pass: compute the hidden and output activations.
        public void getNetOutput() {
            int i, j;
            double tmp = 0.0;
            for (i = 0; i < hiddennum; i++) {
                tmp = 0.0;
                for (j = 0; j < innum; j++)
                    tmp += v[j][i] * input[j];
                hidden[i] = sigmoid(tmp - afa[i]);
            }
            for (i = 0; i < outnum; i++) {
                tmp = 0.0;
                for (j = 0; j < hiddennum; j++)
                    tmp += w[j][i] * hidden[j];
                output[i] = sigmoid(tmp - beta[i]);
            }
        }

        // Randomly initialise the weight matrices w and v in (-1, 1).
        private void randomizeWeights(double[][] matrix) {
            for (int i = 0, len = matrix.length; i != len; i++)
                for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
                    double real = random.nextDouble();
                    matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
                }
        }

        // Dump the current weight matrices for debugging.
        public void debug() {
            System.out.println("========begin=======");
            for (int i = 0; i < innum; i++) {
                for (int j = 0; j < hiddennum; j++)
                    System.out.print(v[i][j] + " ");
                System.out.println();
            }
            System.out.println();
            for (int i = 0; i < hiddennum; i++) {
                for (int j = 0; j < outnum; j++)
                    System.out.print(w[i][j] + " ");
                System.out.println();
            }
            System.out.println("========end=======");
        }

        public double sigmoid(double z) {
            return 1d / (1d + Math.exp(-z));
        }

        public static void main(String[] args) {
            BP bp = new BP(8, 10, 1, 0.1);
            bp.train(50);
        }
    }
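
    For completeness, here is a minimal usage sketch (the Demo class is hypothetical, not part of the original post): it trains the network exactly as BP.main does and then classifies a single test melon with predict().

    package BP;

    // Hypothetical usage example: train, then query one sample.
    public class Demo {
        public static void main(String[] args) {
            BP bp = new BP(8, 10, 1, 0.1);   // 8 inputs, 10 hidden units, 1 output, eta = 0.1
            bp.train(50);                    // 50 passes over the training set

            TrainData t = new TrainData();
            int label = bp.predict(t.testdata[0]);   // 1 = good melon, 0 = bad melon
            System.out.println("test sample 0 predicted as: " + label);
        }
    }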


    A few closing remarks:

            In terms of results, the accuracy on the held-out set can reach 85%, although it depends heavily on the seed given to the Random instance when BP is constructed. With a lucky seed it can even reach 100%; with an unlucky one it drops to around 40%, which is no better than guessing at random.

            A question for anyone more experienced: is randomly trying seeds really the only way to find a good one? Could an open-ended search method such as a genetic algorithm be used to look for a suitable random value? (My feeling is that the seed has no direct, meaningful relationship to the resulting weights, so I am not sure whether genetic-style algorithms even apply here...)
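
            The brute-force version of that seed search is at least easy to write down. The sketch below assumes a hypothetical five-argument constructor that exposes the seed; the posted class hard-codes 19950326, so it would need a small change to accept one.

            package BP;

            // Rough sketch of a brute-force restart over seeds. It assumes a
            // hypothetical constructor BP(in, hidden, out, eta, seed), which the
            // class above does not currently provide.
            public class SeedSearch {
                public static void main(String[] args) {
                    TrainData t = new TrainData();
                    long bestSeed = -1;
                    double bestAcc = 0.0;
                    for (long seed = 0; seed < 100; seed++) {
                        BP bp = new BP(8, 10, 1, 0.1, seed);   // hypothetical 5-argument constructor
                        bp.train(50);
                        double acc = bp.getAccuracy(t.testdata, t.testdataoutput);
                        if (acc > bestAcc) {
                            bestAcc = acc;
                            bestSeed = seed;
                        }
                    }
                    System.out.println("best seed: " + bestSeed + "  test accuracy: " + bestAcc);
                }
            }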

Original post: https://www.cnblogs.com/qiaoyanlin/p/6888617.html