zoukankan      html  css  js  c++  java
  • 简单的神经网络模型java版本

    代码参考http://blog.csdn.net/luxiaoxun/article/details/7649945,原文是个c++,输入二进制输出对应的十进制;学习后,用java重写了下。

    程序描述:输入小写的abcdefghi,会输出它的大写形式

    代码中每行第一次出现//后的内容是原c++程序的注释,

    eg:在void main()中定义char m='a',会输出A;

    package cn.kelaile.ocr;


    public class Bpneuralnettest {


    public staticint innode= 4 ;//输入结点数

    public staticinthidenode= 10;//隐含结点数

    public staticint outnode= 1 ;//输出结点数

    public staticint trainsample=8;//BP训练样本数

    public staticdouble [][]w=newdouble [innode][hidenode];//隐含结点权值

    public staticdouble [][]w1=newdouble[hidenode][outnode];//输出结点权值

    public staticdouble []b1=newdouble[hidenode];//隐含结点阀值

    public staticdouble []b2=newdouble[outnode];//输出结点阀值


    public staticdoublerate_w;//权值学习率(输入层-隐含层)

    public staticdoublerate_w1;//权值学习率 (隐含层-输出层)

    public staticdoublerate_b1;//隐含层阀值学习率

    public staticdoublerate_b2;//输出层阀值学习率


    public staticdoublee;//误差计算

    public staticdoubleerror;//允许的最大误差

    public staticdouble []result=newdouble[outnode];//Bp输出


    public staticdouble [] pattern(charpp){

     

    if(pp=='a'){

    double []p={0,0,0,0};

    returnp;

    }else if(pp=='b'){

    double []p={0,0,1,0};

    returnp;

    }else if(pp=='c'){

    double []p={0,1,0,0};

    returnp;

    }else if(pp=='d'){

    double []p={0,1,1,0};

    returnp;

    }else if(pp=='e'){

    double []p={1,0,0,0};

    returnp;

    }else if(pp=='f'){

    double []p={1,0,1,0};

    returnp;

    }else if(pp=='g'){

    double []p={1,1,0,0};

    returnp;

    }else if(pp=='h'){

    double []p={1,1,1,0};

    returnp;

    }else if(pp=='i'){

    double []p={1,1,1,1};

    returnp;

    }else{

    System.out.println("参数为abcdefgh");

    returnnull;

    }

    }

    public staticdouble [] recognize(charpp)

    {

        double []x=newdouble [innode];//输入向量

        double []x1=newdouble[hidenode];//隐含结点状态值

        double []x2=newdouble[outnode];//输出结点状态值

        double []o1=newdouble[hidenode];//隐含层激活值

        double []o2=newdouble[hidenode];//输出层激活值

        

       double []p=pattern(pp);

        for(inti=0;i<innode;i++)

            x[i]=p[i];

        

        for(intj=0;j<hidenode;j++)

        {

            o1[j]=0.0;

            for(inti=0;i<innode;i++)

                o1[j]=o1[j]+w[i][j]*x[i];//隐含层各单元激活值

            x1[j]=1.0/(1.0+Math.pow(Math.E, -o1[j]-b1[j]));//隐含层各单元输出

            //if(o1[j]+b1[j]>0) x1[j]=1;

            //    else x1[j]=0;

        }

        

        for(intk=0;k<outnode;k++)

        {

            o2[k]=0.0;

            for(intj=0;j<hidenode;j++)

                o2[k]=o2[k]+w1[j][k]*x1[j];//输出层各单元激活值

            x2[k]=1.0/(1.0+Math.pow(Math.E, (-o2[k]-b2[k])));//输出层各单元输出

            //if(o2[k]+b2[k]>0) x2[k]=1;

            //else x2[k]=0;

        }

        

        for(intk=0;k<outnode;k++)

        {

            result[k]=x2[k];

        }

        return result;

    }

    public staticvoid train(double [][]p,double [][]t)

    {

        double []pp=newdouble[hidenode];//隐含结点的校正误差

        double []qq=newdouble[outnode];//希望输出值与实际输出值的偏差

        double []yd=newdouble[outnode];//希望输出值

        

        double []x=newdouble[innode];//输入向量

        double []x1=newdouble[hidenode];//隐含结点状态值

        double []x2=newdouble[outnode];//输出结点状态值

        double []o1=newdouble[hidenode];//隐含层激活值

        double []o2=newdouble[hidenode];//输出层激活值

        

        for(intisamp=0;isamp<trainsample;isamp++)//循环训练一次样品

        {

            for(inti=0;i<innode;i++)

                x[i]=p[isamp][i];//输入的样本

            for(inti=0;i<outnode;i++)

                yd[i]=t[isamp][i];//期望输出的样本

            

            //构造每个样品的输入和输出标准

            for(intj=0;j<hidenode;j++)

            {

                o1[j]=0.0;

                for(inti=0;i<innode;i++)

                    o1[j]=o1[j]+w[i][j]*x[i];//隐含层各单元输入激活值

                x1[j]=1.0/(1+Math.pow(Math.E, (-o1[j]-b1[j])));//隐含层各单元的输出

                //    if(o1[j]+b1[j]>0) x1[j]=1;

                //else x1[j]=0;

            }

            

            for(intk=0;k<outnode;k++)

            {

                o2[k]=0.0;

                for(intj=0;j<hidenode;j++)

                    o2[k]=o2[k]+w1[j][k]*x1[j];//输出层各单元输入激活值

                x2[k]=1.0/(1.0+Math.pow(Math.E, (-o2[k]-b2[k])));//输出层各单元输出//这里就是用到sigmod函数了

                //    if(o2[k]+b2[k]>0) x2[k]=1;

                //    else x2[k]=0;

            }

            

            for(intk=0;k<outnode;k++)

            {

                qq[k]=(yd[k]-x2[k])*x2[k]*(1-x2[k]);//希望输出与实际输出的偏差//用y表示simod函数,可知道y‘=y(1-y),这个自己推导下就知道了

                for(intj=0;j<hidenode;j++)

                    w1[j][k]+=rate_w1*qq[k]*x1[j]; //下一次的隐含层和输出层之间的新连接权

            }

            

            for(intj=0;j<hidenode;j++)

            {

                pp[j]=0.0;

                for(intk=0;k<outnode;k++)

                    pp[j]=pp[j]+qq[k]*w1[j][k];

                pp[j]=pp[j]*x1[j]*(1-x1[j]);//隐含层的校正误差

                

                for(inti=0;i<innode;i++)

                    w[i][j]+=rate_w*pp[j]*x[i];//下一次的输入层和隐含层之间的新连接权

            }

            

            for(intk=0;k<outnode;k++)

            {

                e+=Math.abs(yd[k]-x2[k])*Math.abs(yd[k]-x2[k]);//计算均方差

            }

            error=e/2.0;

            

            for(intk=0;k<outnode;k++)

                b2[k]=b2[k]+rate_b2*qq[k];//下一次的隐含层和输出层之间的新阈值//这个就是调整权值很常用的求法或者建模方法了,hash映射时候也是这么求的。

            for(intj=0;j<hidenode;j++)

                b1[j]=b1[j]+rate_b1*pp[j];//下一次的输入层和隐含层之间的新阈值

        }

    }

    //输入样本

    public staticdouble [][]X= {

      {0,0,0,0},{0,0,1,0},{0,1,0,0},{0,1,1,0},{1,0,0,0},{1,0,1,0},{1,1,0,0},{1,1,1,0},{1,1,1,1}

    };

    public static String[]Sample={"ABCDEFGH"


    };

    //期望输出样本

    public staticdouble [][]Y={

      {0},{0.125},{0.250},{0.375},{0.500},{0.625},{0.850},{0.975},{1.0000}

    };

    public staticvoid winit(doublew[],intn)//权值初始化

    {

        for(inti=0;i<n;i++)

            w[i]=(2.0*(double)Math.random())-1;

    }

    public staticvoid winit(doublew[][],intn)//权值初始化

    {

    for(intj=0;j<w.length;j++){


        for(inti=0;i<w[j].length;i++){

            w[j][i]=(2.0*(double)Math.random())-1;

            }

        }

    }



    public staticvoid init(){

    winit(w,innode*hidenode);

        winit(w1,hidenode*outnode);

        winit(b1,hidenode);

        winit(b2,outnode);

    }


    public static void main(String[]args) {

    // TODO Auto-generated method stub

    error= 1.0;

        e=0.0f;

       

        rate_w= 0.9; //权值学习率(输入层--隐含层)

        rate_w1= 0.9;//权值学习率 (隐含层--输出层)

        rate_b1= 0.9; //隐含层阀值学习率

        rate_b2= 0.9; //输出层阀值学习率

        init();

        int times=0;

        while(error>0.0001)

        {

            e=0.0;

            times++;

            train(X,Y);

            System.out.println("Times="+times+" error="+error);

           

        }

        System.out.println("trainning complete...");

       

        char m='f';

        double[] r=recognize(m);

        for(inti=0;i<outnode;++i){

           

        System.out.println(result[i]+" ");

        }

        double [][]cha=newdouble[trainsample][outnode];

        double mi=100;

        int index=99999;

        for(inti=0;i<trainsample;i++)

        {

            for(intj=0;j<outnode;j++)

            {

                //找差值最小的那个样本

                cha[i][j]=(double)(Math.abs(Y[i][j]-result[j]));

                if(cha[i][j]<mi)

                {

                    mi=cha[i][j];

                    index=i;

                }

            }

        }

       

            System.out.println(m);

       

        String res="ABCDEFGHI";

        char[] ch = res.toCharArray();

       

        System.out.println(" is "+ch[index]);

       

       


    }


    }


  • 相关阅读:
    从今天开始,记录学习的点滴。
    git命令整理
    vue ie报错:[Vue warn]: Error in v-on handler: "ReferenceError: “Promise”未定义"
    HTML5知识整理
    解决You are using the runtime-only build of Vue where the template compiler is not available. Either pre
    HTML5本地存储
    网站建设流程图说明
    vue支持的修饰符(常用整理)
    vue绑定内联样式
    vue绑定class的几种方式
  • 原文地址:https://www.cnblogs.com/zhangdebin/p/5567955.html
Copyright © 2011-2022 走看看