  • Machine learning demo: classifying watermelons

    Following Professor Zhou's book, I wrote a small demo of a neural network.

    It is the simplest possible network, with only a single hidden layer.

    As in the earlier exercises, the task is to predict whether a watermelon is good or bad.

    The work splits into the following steps.

    1. Data preprocessing

    Each categorical watermelon attribute is encoded as a number between 0 and 1; I simply assigned the values by hand. The sugar content is already a value between 0 and 1, so it needed no further processing. The encoding is listed below, followed by a small helper sketch.

    Color (色泽):    青绿 (green) = 1,       乌黑 (dark) = 0.5,            浅白 (pale) = 0

    Root (根蒂):     蜷缩 (curled) = 1,      稍蜷 (slightly curled) = 0.5, 硬挺 (stiff) = 0

    Knock (敲声):    浊响 (dull) = 1,        沉闷 (muffled) = 0.5,         清脆 (crisp) = 0

    Texture (纹理):  清晰 (clear) = 1,       稍糊 (slightly blurry) = 0.5, 模糊 (blurry) = 0

    Navel (脐部):    凹陷 (sunken) = 1,      稍凹 (slightly dented) = 0.5, 平坦 (flat) = 0

    Touch (触感):    硬滑 (hard-smooth) = 1, 软黏 (soft-sticky) = 0
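
    As a rough sketch (my own addition, not part of the original post; the class name and the encode helper are made up for illustration), the hand-assigned mapping could be written in Java like this:

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical helper that maps a categorical attribute value to its numeric code.
    public class AttributeEncoder {
        private static final Map<String, Double> CODES = new HashMap<>();
        static {
            // Color
            CODES.put("青绿", 1.0); CODES.put("乌黑", 0.5); CODES.put("浅白", 0.0);
            // Root
            CODES.put("蜷缩", 1.0); CODES.put("稍蜷", 0.5); CODES.put("硬挺", 0.0);
            // Knock sound
            CODES.put("浊响", 1.0); CODES.put("沉闷", 0.5); CODES.put("清脆", 0.0);
            // Texture
            CODES.put("清晰", 1.0); CODES.put("稍糊", 0.5); CODES.put("模糊", 0.0);
            // Navel
            CODES.put("凹陷", 1.0); CODES.put("稍凹", 0.5); CODES.put("平坦", 0.0);
            // Touch
            CODES.put("硬滑", 1.0); CODES.put("软黏", 0.0);
        }

        public static double encode(String value) {
            Double code = CODES.get(value);
            if (code == null) throw new IllegalArgumentException("unknown value: " + value);
            return code;
        }

        public static void main(String[] args) {
            // First training melon: 青绿, 蜷缩, 浊响, 清晰, 凹陷, 硬滑, density 0.697, sugar 0.460
            double[] melon = {encode("青绿"), encode("蜷缩"), encode("浊响"),
                              encode("清晰"), encode("凹陷"), encode("硬滑"), 0.697, 0.460};
            for (double x : melon) System.out.print(x + " ");
        }
    }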

    2. Training set and test set

    package BP;

    public class TrainData {
        double[][] traindata;
        double[][] traindataoutput;
        double[][] testdata;
        double[][] testdataoutput;

        public TrainData() {
            // Each row holds 8 features: the six encoded attributes (in the order listed
            // above) followed by density and sugar content. The commented-out rows are the
            // melons that were moved into the test set.
            traindata = new double[][]{
                new double[]{1, 1, 1, 1, 1, 1, 0.697, 0.460},
                new double[]{0.5, 1, 0.5, 1, 1, 1, 0.774, 0.376},
                new double[]{0.5, 1, 1, 1, 1, 1, 0.634, 0.264},
                //new double[]{1, 1, 0.5, 1, 1, 1, 0.608, 0.318, 1},
                //new double[]{0, 1, 1, 1, 1, 1, 0.556, 0.215, 1},
                new double[]{1, 0.5, 1, 1, 0.5, 0, 0.403, 0.237},
                new double[]{0.5, 0.5, 1, 0.5, 0.5, 0, 0.481, 0.149},
                //new double[]{0.5, 0.5, 1, 1, 0.5, 1, 0.437, 0.211, 1},

                //new double[]{0.5, 0.5, 0.5, 0.5, 0.5, 1, 0.666, 0.091, 0},
                //new double[]{1, 0, 0, 1, 0, 0, 0.243, 0.267, 0},
                //new double[]{0, 0, 0, 0, 0, 1, 0.245, 0.057, 0},
                //new double[]{0, 1, 1, 0, 0, 0, 0.343, 0.099, 0},
                new double[]{1, 0.5, 1, 0.5, 1, 1, 0.639, 0.161},
                new double[]{0, 0.5, 0, 0.5, 1, 1, 0.657, 0.198},
                new double[]{0.5, 0.5, 1, 1, 0.5, 0, 0.360, 0.370},
                new double[]{0, 1, 1, 0, 0, 1, 0.593, 0.042},
                new double[]{1, 1, 0.5, 0.5, 0.5, 1, 0.719, 0.103}
            };
            // labels: 1 = good melon, 0 = bad melon
            traindataoutput = new double[][]{
                new double[]{1},
                new double[]{1},
                new double[]{1},
                new double[]{1},
                new double[]{1},
                new double[]{0},
                new double[]{0},
                new double[]{0},
                new double[]{0},
                new double[]{0},
            };
            testdata = new double[][]{
                new double[]{1, 1, 0.5, 1, 1, 1, 0.608, 0.318},
                new double[]{0, 1, 1, 1, 1, 1, 0.556, 0.215},
                new double[]{0.5, 0.5, 1, 1, 0.5, 1, 0.437, 0.211},

                new double[]{0.5, 0.5, 0.5, 0.5, 0.5, 1, 0.666, 0.091},
                new double[]{1, 0, 0, 1, 0, 0, 0.243, 0.267},
                new double[]{0, 0, 0, 0, 0, 1, 0.245, 0.057},
                new double[]{0, 1, 1, 0, 0, 0, 0.343, 0.099},
            };
            testdataoutput = new double[][]{
                new double[]{1},
                new double[]{1},
                new double[]{1},
                new double[]{0},
                new double[]{0},
                new double[]{0},
                new double[]{0},
            };
        }

        // quick check: print the training matrix
        public static void main(String[] args) {
            TrainData t = new TrainData();
            for (int i = 0; i < t.traindata.length; i++) {
                // each row has 8 entries (the original loop used j < 9, which overruns the row)
                for (int j = 0; j < t.traindata[i].length; j++)
                    System.out.print(t.traindata[i][j] + " ");
                System.out.println();
            }
        }
    }

    3. The BP network (main program)

    package BP;

    import java.util.Random;

    public class BP {
        int innum;
        int hiddennum;
        int outnum;
        // input, hidden and output layers
        public double[] input;
        public double[] hidden;
        // output holds the values computed by the network
        public double[] output;

        // realoutput holds the target values supplied during training
        public double[] realoutput;

        // v[i][j]: input unit i to hidden unit j; w[i][j]: hidden unit i to output unit j
        public double[][] v;
        public double[][] w;

        // beta: thresholds of the output layer; afa: thresholds of the hidden layer
        public double[] beta;
        public double[] afa;

        // learning rate
        public double eta;
        // momentum term (declared but never used in this demo)
        public double momentum;
        public final Random random;

        public BP(int inputnum, int hiddennum, int outputnum, double learningrate) {
            innum = inputnum;
            this.hiddennum = hiddennum;
            outnum = outputnum;

            input = new double[inputnum + 1];
            hidden = new double[hiddennum + 1];
            output = new double[outputnum + 1];
            realoutput = new double[outputnum + 1];

            v = new double[inputnum + 1][hiddennum + 1];
            w = new double[hiddennum + 1][outputnum + 1];

            beta = new double[outputnum + 1];
            afa = new double[hiddennum + 1];
            for (int i = 0; i < outputnum; i++)
                beta[i] = 0.0;
            for (int i = 0; i < hiddennum; i++)
                afa[i] = 0.0;

            eta = learningrate;
            // the random seed has a large effect on the results
            random = new Random(19950326);
            randomizeWeights(w);
            randomizeWeights(v);
        }

        public void testData(double[] in) {
            input = in;
            getNetOutput();
        }

        // specific to this task: output > 0.5 means a good melon, otherwise a bad one
        public int predict(double[] in) {
            testData(in);
            if (output[0] > 0.5)
                return 1;
            else
                return 0;
        }

        // accuracy on the test set
        public double getAccuracy(double[][] in, double[][] out) {
            int rightans = 0, wrongans = 0;
            for (int i = 0; i < in.length; i++) {
                if (predict(in[i]) == (out[i][0])) {
                    //System.out.println("predicted: " + predict(in[i]) + " actual: " + out[i][0]);
                    rightans++;
                } else {
                    //System.out.println("predicted: " + predict(in[i]) + " actual: " + out[i][0]);
                    wrongans++;
                }
            }
            System.out.println("right: " + rightans + " wrong: " + wrongans);
            return (double) rightans / (double) (rightans + wrongans);
        }

        // times = number of training epochs
        public void train(int times) {
            TrainData t = new TrainData();
            double wu = 0.0, acc = 0.0;
            int n = t.traindata.length;
            for (int i = 0; i < times; i++) {
                wu = 0.0;
                for (int j = 0; j < n; j++) {
                    traindata(t.traindata[j], t.traindataoutput[j]);
                    wu += getDeviation();
                }
                wu = wu / ((double) n);
                System.out.println("epoch " + i + " mean error: " + wu);
                acc = getAccuracy(t.testdata, t.testdataoutput);
                System.out.println("test accuracy: " + acc);
            }
        }

        // train on a single input example
        public void traindata(double[] in, double[] out) {
            input = in;
            realoutput = out;
            getNetOutput();
            adjustParameter();
        }

        // squared error E for the current example
        public double getDeviation() {
            double e = 0.0;
            for (int i = 0; i < outnum; i++)
                e += (output[i] - realoutput[i]) * (output[i] - realoutput[i]);
            e *= 0.5;
            return e;
        }

        // update the weights and thresholds (standard BP gradient step)
        public void adjustParameter() {
            double g[], e = 0.0;
            g = new double[outnum];
            int i, j;
            for (i = 0; i < outnum; i++) {
                g[i] = output[i] * (1 - output[i]) * (realoutput[i] - output[i]);
                beta[i] -= eta * g[i];
                for (j = 0; j < hiddennum; j++) {
                    w[j][i] += eta * g[i] * hidden[j];
                }
            }
            for (i = 0; i < hiddennum; i++) {
                e = 0.0;
                for (j = 0; j < outnum; j++)
                    e += g[j] * w[i][j];
                e = hidden[i] * (1 - hidden[i]) * e;
                afa[i] -= eta * e;
                for (j = 0; j < innum; j++)
                    v[j][i] += eta * e * input[j];
            }
        }

        // forward pass: compute hidden and output activations
        public void getNetOutput() {
            int i, j;
            double tmp = 0.0;
            for (i = 0; i < hiddennum; i++) {
                tmp = 0.0;
                for (j = 0; j < innum; j++)
                    tmp += v[j][i] * input[j];
                hidden[i] = sigmoid(tmp - afa[i]);
            }
            for (i = 0; i < outnum; i++) {
                tmp = 0.0;
                for (j = 0; j < hiddennum; j++)
                    tmp += w[j][i] * hidden[j];
                output[i] = sigmoid(tmp - beta[i]);
            }
        }

        // randomly initialize the weight matrices w and v
        private void randomizeWeights(double[][] matrix) {
            for (int i = 0, len = matrix.length; i != len; i++)
                for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
                    double real = random.nextDouble();
                    matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
                }
        }

        // print the current weight matrices
        public void debug() {
            System.out.println("========begin=======");
            for (int i = 0; i < innum; i++) {
                for (int j = 0; j < hiddennum; j++)
                    System.out.print(v[i][j] + " ");
                System.out.println();
            }
            System.out.println();
            for (int i = 0; i < hiddennum; i++) {
                for (int j = 0; j < outnum; j++)
                    System.out.print(w[i][j] + " ");
                System.out.println();
            }
            System.out.println("========end=======");
        }

        public double sigmoid(double z) {
            return 1d / (1d + Math.exp(-z));
        }

        public static void main(String[] args) {
            BP bp = new BP(8, 10, 1, 0.1);
            bp.train(50);
        }
    }
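
    For reference, here is a minimal usage sketch (my own addition, not from the original post): after training, predict can be called on a single encoded melon. The feature vector below is simply the first row of testdata.

    package BP;

    // Hedged usage sketch: train the network, then classify one encoded melon.
    public class Demo {
        public static void main(String[] args) {
            BP bp = new BP(8, 10, 1, 0.1);   // 8 inputs, 10 hidden units, 1 output, learning rate 0.1
            bp.train(50);                    // 50 epochs over the training set

            // first row of the test set above
            double[] melon = {1, 1, 0.5, 1, 1, 1, 0.608, 0.318};
            int label = bp.predict(melon);   // 1 = good melon, 0 = bad melon
            System.out.println(label == 1 ? "good melon" : "bad melon");
        }
    }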


    A few remarks:

    In terms of results, the accuracy on the test set can reach 85%, though this depends heavily on the seed of the Random instance created when BP is initialized. With a lucky seed it can even hit 100%; with an unlucky one it drops to around 40%, which is no better than guessing at random.

    A question for more experienced readers: is trial and error over random seeds really the only way to find a good initialization? Could an open-ended search such as a genetic algorithm be used instead? (My feeling is that the seed value and the random numbers it produces have no direct relationship, so I am not sure whether a genetic algorithm or anything like it even applies here...)
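
    One simple baseline (my own sketch, not from the post) is a plain random restart: train the network several times with different seeds and keep whichever run scores best. It assumes a hypothetical BP constructor that takes a seed parameter in place of the hard-coded 19950326, and note that picking the seed by its score on this tiny test set is itself a form of overfitting to that set.

    package BP;

    // Hedged sketch of a random-restart search over seeds (not from the original post).
    // Assumes a hypothetical constructor BP(int in, int hidden, int out, double lr, long seed)
    // that passes the seed to new Random(seed).
    public class SeedSearch {
        public static void main(String[] args) {
            TrainData t = new TrainData();
            long bestSeed = -1;
            double bestAcc = -1.0;
            for (long seed = 0; seed < 200; seed++) {
                BP bp = new BP(8, 10, 1, 0.1, seed);   // hypothetical seed-taking constructor
                bp.train(50);
                double acc = bp.getAccuracy(t.testdata, t.testdataoutput);
                if (acc > bestAcc) {
                    bestAcc = acc;
                    bestSeed = seed;
                }
            }
            System.out.println("best seed: " + bestSeed + " accuracy: " + bestAcc);
        }
    }

    A genetic algorithm would more naturally search over the initial weight vectors themselves rather than over the seed, since, as noted above, the seed has no structure that relates smoothly to the outcome.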

    Original post: https://www.cnblogs.com/qiaoyanlin/p/6888617.html