zoukankan      html  css  js  c++  java
  • bp神经网络的实现C++ (BP neural network implementation in C++)

    #include <iostream>
    #include<stdlib.h>
    #include <math.h>
    using namespace std;
    
    // Network topology and training-set size; used as array bounds throughout.
    // NOTE(review): these are deliberately left as macros. A namespace-scope
    // variable named `sample` could become ambiguous with std::sample (C++17)
    // because of `using namespace std;` above.
    // Number of input units (columns of each training row in X).
    #define  innode 2  
    // Number of hidden units.
    #define  hiddennode 10
    // Number of output units (columns of each target row in Y).
    #define  outnode 1 
    // Number of training samples (rows of X and Y).
    #define  sample 4
    // A fully connected network with one hidden layer and sigmoid activations,
    // trained by gradient-descent backpropagation (see bpnet::train).
    class bpnet
    {
    public:
        // --- Parameters -------------------------------------------------
        double w1[hiddennode][innode];   // input  -> hidden weights, row per hidden unit
        double w2[outnode][hiddennode];  // hidden -> output weights, row per output unit
        double b1[hiddennode];           // hidden-unit biases
        double b2[outnode];              // output-unit biases

        // --- Training bookkeeping ---------------------------------------
        double e;      // running sum of squared output errors, accumulated by train()
        double error;  // e / 2, refreshed by train()
        double lr;     // learning rate used for every weight/bias update

        // --- Lifetime -----------------------------------------------------
        bpnet();
        ~bpnet();

        // --- Initialization ---------------------------------------------
        void init();
        void initw(double* w, int n);
        double randval(double low, double high);

        // --- Activation ---------------------------------------------------
        double sigmod(double y);
        double dsigmod(double y);

        // --- Training / inference ---------------------------------------
        void train(double (*p)[innode], double (*t)[outnode]);
        void predict(double* p);
    };
    // Derivative of the logistic sigmoid, expressed in terms of its OUTPUT:
    // if s = sigmod(x) then ds/dx = s * (1 - s). Callers pass the activation s,
    // not the pre-activation x.
    double bpnet::dsigmod(double y)
    {
        const double complement = 1.0 - y;
        return y * complement;
    }
    // Logistic sigmoid 1 / (1 + e^(-y)). It maps a real input into (0, 1); in
    // floating point it saturates to 0 or 1 for large |y|.
    double bpnet::sigmod(double y)
    {
        const double denom = 1 + exp(-y);
        return 1.0 / denom;
    }
    // Returns a pseudo-random double in [low, high], computed by scaling rand()
    // by RAND_MAX. No srand() call appears in this file, so the sequence is the
    // same on every run.
    double bpnet::randval(double low, double high)
    {
        const double unit = static_cast<double>(rand()) / static_cast<double>(RAND_MAX);
        return low + unit * (high - low);
    }
    // Fills the first n doubles starting at w with random values in
    // [-0.01, 0.01]. Does nothing when n <= 0.
    void bpnet::initw(double w[], int n)
    {
        int idx = 0;
        while (idx < n)
        {
            w[idx] = randval(-0.01, 0.01);
            ++idx;
        }
    }
    // Randomizes every weight and bias. Each 2-D weight matrix is filled through
    // a pointer to its first element, so it is treated as rows*cols contiguous
    // doubles. The call order (w1, w2, b1, b2) fixes which rand() values each
    // parameter receives.
    void bpnet::init()
    {
        initw(&w1[0][0], innode * hiddennode);
        initw(&w2[0][0], outnode * hiddennode);
        initw(&b1[0], hiddennode);
        initw(&b2[0], outnode);
    }
    // Runs one epoch of online (per-sample) backpropagation over the training set.
    //
    // For each of the `sample` rows it:
    //   1. runs a forward pass (input -> sigmoid hidden layer -> sigmoid output);
    //   2. computes output deltas, and hidden deltas from the PRE-update w2;
    //   3. applies gradient-descent steps of size lr to w2, b2, w1 and b1.
    //
    // p: training inputs, one row of `innode` values per sample.
    // t: target outputs, one row of `outnode` values per sample.
    //
    // Side effects: adds each sample's squared output error to `e`, then sets
    // `error = e / 2`. `e` is NOT reset here; the caller resets it before each
    // epoch (main does).
    void bpnet::train(double p[sample][innode], double t[sample][outnode])
    {
        double hiddenout[hiddennode];
        double outout[outnode];
        double outerr[outnode];
        double hiddenerr[hiddennode];
        for (int k = 0; k < sample; k++)
        {
            const double* x = p[k];  // current input row
            const double* d = t[k];  // current target row

            // Forward pass: input -> hidden.
            for (int i = 0; i < hiddennode; i++)
            {
                double net = 0.0;
                for (int j = 0; j < innode; j++)
                {
                    net += w1[i][j] * x[j];
                }
                hiddenout[i] = sigmod(net + b1[i]);
            }
            // Forward pass: hidden -> output.
            for (int i = 0; i < outnode; i++)
            {
                double net = 0.0;
                for (int j = 0; j < hiddennode; j++)
                {
                    net += w2[i][j] * hiddenout[j];
                }
                outout[i] = sigmod(net + b2[i]);
            }

            // Output deltas for squared error with sigmoid outputs.
            for (int i = 0; i < outnode; i++)
            {
                outerr[i] = (d[i] - outout[i]) * dsigmod(outout[i]);
            }
            // Hidden deltas. These must be computed before w2 is modified;
            // back-propagating through already-updated weights gives a wrong gradient.
            for (int i = 0; i < hiddennode; i++)
            {
                double back = 0.0;
                for (int j = 0; j < outnode; j++)
                {
                    back += w2[j][i] * outerr[j];
                }
                hiddenerr[i] = back * dsigmod(hiddenout[i]);
            }

            // Parameter updates. Biases accumulate (+=) like weights do; the
            // bias's "input" is a constant 1.
            for (int i = 0; i < outnode; i++)
            {
                for (int j = 0; j < hiddennode; j++)
                {
                    w2[i][j] += lr * outerr[i] * hiddenout[j];
                }
                b2[i] += lr * outerr[i];
            }
            for (int i = 0; i < hiddennode; i++)
            {
                for (int j = 0; j < innode; j++)
                {
                    w1[i][j] += lr * hiddenerr[i] * x[j];
                }
                b1[i] += lr * hiddenerr[i];
            }

            // Accumulate this sample's squared error.
            for (int i = 0; i < outnode; i++)
            {
                const double diff = d[i] - outout[i];
                e += diff * diff;
            }
        }
        error = e / 2.0;
    }
    void bpnet::predict(double p[])
    {
        double hiddenin[hiddennode];
        double hiddenout[hiddennode];
        double outin[outnode];
        double outout[outnode];
        double x[innode];
        for (int i = 0; i < innode; i++)
        {
            x[i] = p[i];
        }
        for (int i = 0; i < hiddennode; i++)
        {
            hiddenin[i] = 0.0;
            for (int j = 0; j < innode; j++)
            {
                hiddenin[i] += w1[i][j] * x[j];
            }
            hiddenout[i] = sigmod(hiddenin[i] + b1[i]);
        }
        for (int i = 0; i < outnode; i++)
        {
            outin[i] = 0.0;
            for (int j = 0; j < hiddennode; j++)
            {
                outin[i] += w2[i][j] * hiddenout[j];
            }
            outout[i] = sigmod(outin[i] + b2[i]);
        }
        for (int i = 0; i < outnode; i++)
        {
            cout << "the prediction is"<<outout[i] << endl;
        }
    }
    // Sets up the training state:
    //   - the error accumulator e starts at 0.0;
    //   - error starts at 1.0, so the training loop in main runs at least once;
    //   - the learning rate is 0.4.
    // Weights and biases are NOT initialized here; call init() before training.
    bpnet::bpnet()
        : e(0.0), error(1.0), lr(0.4)
    {
    }
    // Nothing to release: all parameters live in fixed-size member arrays.
    bpnet::~bpnet() = default;
    
    // Training set: all four 2-bit input patterns. The target is 1 exactly when
    // the two inputs are equal (an XNOR truth table).
    double X[sample][innode] = {{1, 1}, {1, 0}, {0, 1}, {0, 0}};
    double Y[sample][outnode] = {{1}, {0}, {0}, {1}};
    // Trains the network until its error (half the epoch's summed squared error)
    // drops to 0.001 or 10000 epochs have run. It then prints the prediction for
    // input (0, 1).
    int main()
    {
        bpnet bp;
        bp.init();

        const int maxEpochs = 10000;
        const double targetError = 0.001;
        for (int epoch = 1; bp.error > targetError && epoch <= maxEpochs; ++epoch)
        {
            bp.e = 0.0;  // train() accumulates into e, so reset it every epoch
            bp.train(X, Y);
        }

        double m[2] = {0, 1};
        bp.predict(m);
        return 0;
    }
  • 相关阅读:
    函数学习(JY07-JavaScript-JS基础03)
    数据类型转换中的一些特殊情况(JY06-JavaScript)
    JY05-JavsScript-JS基础01
    JY03-HTML/CSS-京东02
    JY02-HTML/CSS-京东01 定位是很粗暴的页面布局方法
    JY01-KX-01
    清浮动的几种方法
    HTML/CSS一些需要注意的基础知识
    linux 实时检测web项目MD5防止网站被黑
    面试总结第一谈
  • 原文地址:https://www.cnblogs.com/semen/p/6883438.html
Copyright © 2011-2022 走看看