zoukankan      html  css  js  c++  java
  • 感知器算法 C++

     We can estimate the weight values for our training data using stochastic gradient descent.

    Stochastic gradient descent requires two parameters:

    • Learning Rate: Used to limit the amount each weight is corrected each time it is updated.
    • Epochs: The number of times to run through the training data while updating the weights.

    These, along with the training data, will be the arguments to the function.

    There are 3 loops we need to perform in the function:

    1. Loop over each epoch.
    2. Loop over each row in the training data for an epoch.
    3. Loop over each weight and update it for a row in an epoch.

    As you can see, we update each weight for each row in the training data, each epoch.

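    The loop runs until either:

      the iteration error is less than a user-specified error threshold, or

      a predetermined number of iterations has been completed.

    (The code below implements only the second criterion: it always runs for a fixed number of epochs.)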

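    Weights are updated based on the error the model made. The error is calculated as the difference between the expected output value and the prediction made with the candidate weights. Each weight is updated as $w_k \leftarrow w_k + \eta\,(y - \hat{y})\,x_k$, where $\eta$ is the learning rate, $y$ the expected output, $\hat{y}$ the prediction, and $x_k$ the input paired with weight $w_k$.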

    Notice that learning only occurs when an error is made, otherwise the weights are left unchanged.

    #include <cmath>
    #include <fstream>
    #include <iostream>
    #include <sstream>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Perceptron activation (a step function, despite the name "sign").
    // Computes the weighted sum of the first weights.size() entries of data
    // and returns 1.0 when that sum is >= 0, otherwise 0.0.
    //
    // data    feature values; must hold at least weights.size() entries.
    //         Extra trailing entries (e.g. a label column) are ignored.
    //         Too few entries makes data.at() throw std::out_of_range.
    // weights one weight per feature.
    // Both arguments are only read, so they are taken by const reference
    // (this also allows const vectors and temporaries to be passed).
    template <typename DataType, typename WeightType>
    double sign(const std::vector<DataType> &data, const std::vector<WeightType> &weights) {
        double result=0.0;

        for(size_t i=0; i<weights.size(); ++i) {
            result += data.at(i)*weights.at(i);
        }

        return result >= 0.0 ? 1.0 : 0.0;
    }

    // Trains perceptron weights with stochastic gradient descent.
    //
    // Every weight is first reset to 0, so any values passed in are discarded
    // and only weights.size() matters. Then, for n_epoch epochs, each row of vv
    // is processed in order:
    //   - the first weights.size() columns are the features,
    //   - the last column is the expected output,
    //   - error = expected - prediction (prediction from sign()),
    //   - each weight k is moved by l_rate * error * feature_k.
    // A correctly classified row has error 0 and leaves the weights unchanged.
    // Each row, its prediction, the per-epoch sum of squared errors and the
    // final weights are printed to std::cout.
    //
    // vv      training rows; each must have more than weights.size() columns.
    //         Not modified by this function.
    // weights in/out: overwritten with the trained weights.
    // l_rate  learning rate.
    // n_epoch number of passes over vv; values <= 0 perform no training.
    // Throws std::invalid_argument if a row has too few columns.
    template <typename DataType, typename WeightType>
    void trainW(std::vector<std::vector<DataType> > &vv, std::vector<WeightType> &weights, const double& l_rate, const int& n_epoch) {
        const size_t n_features=weights.size();

        for(size_t k=0; k<n_features; ++k) {
            weights[k]=0.0;
        }

        // int counter: comparing size_t against a negative n_epoch would wrap
        // to a huge bound and loop (almost) forever.
        for(int epoch=0; epoch<n_epoch; ++epoch) {
            double sum_error=0.0;

            for(size_t j=0; j<vv.size(); ++j) {
                std::vector<DataType> &row=vv[j];
                // Need n_features feature columns plus one label column; a row of
                // exactly n_features would silently use a feature as the label.
                if(row.size()<=n_features) {
                    throw std::invalid_argument("trainW: training row has fewer columns than features + label");
                }

                for(size_t k=0; k<n_features; ++k) {
                    std::cout<<row[k]<<" ";
                }
                std::cout<<'\n';

                // sign() reads only the first n_features entries, so the row can be
                // passed directly instead of copying the features out.
                double prediction=sign(row, weights);
                double error=row.back()-prediction;
                std::cout<<"expected: "<<row.back()<<" prediction: "<<prediction<<" error: "<<error<<'\n';

                sum_error+=error*error;

                for(size_t k=0; k<n_features; ++k) {
                    weights[k]=weights[k]+l_rate*error*row[k];
                }
            }
            std::cout<<"epoch = "<<epoch<<" error = "<<sum_error<<'\n';
        }

        for(size_t k=0; k<n_features; ++k) {
            std::cout<<weights[k]<<" ";
        }
        std::cout<<std::endl;
    }

    // Makes a prediction for every row of vv and appends it as a new last column.
    //
    // The first weights.size() columns of each row are the features. The
    // prediction is sign()'s 0/1 output for those features, converted to DataType.
    //
    // vv      rows to classify; modified in place. One column is appended per
    //         row, so calling this twice appends two columns.
    // weights trained weights, one per feature.
    // Throws std::invalid_argument if any row has fewer than weights.size()
    // columns. All rows are validated first, so on failure vv is left unchanged.
    template <typename DataType, typename WeightType>
    void predictTestData(std::vector<std::vector<DataType> > &vv, std::vector<WeightType> &weights) {
        const size_t n_features=weights.size();

        for(size_t i=0; i<vv.size(); ++i) {
            if(vv[i].size()<n_features) {
                throw std::invalid_argument("predictTestData: row has fewer columns than weights");
            }
        }

        for(size_t i=0; i<vv.size(); ++i) {
            std::vector<DataType> &row=vv[i];
            // sign() reads only the first n_features entries, so no copy is needed.
            const double signResult=sign(row, weights);
            row.push_back(signResult);
        }
    }

    // Prints the number of rows in vv, then each row on its own line with every
    // value followed by a single space. vv is only read, so it is taken by const
    // reference. Output is flushed once at the end instead of after every line.
    template <typename DataType>
    void DisplayData(const std::vector<std::vector<DataType> > &vv) {
        std::cout<<"the number of data: "<<vv.size()<<'\n';

        for(size_t i=0; i<vv.size(); ++i) {
            for(typename std::vector<DataType>::const_iterator it=vv[i].begin(); it!=vv[i].end(); ++it) {
                std::cout<<*it<<" ";
            }
            std::cout<<'\n';
        }
        std::cout<<std::flush;
    }

    int main() {
        std::ifstream infile_feat("PLA.txt");
        std::string feature;
        float feat_onePoint;
        std::vector<float> lines;
        std::vector<std::vector<float> > lines_feat;
        lines_feat.clear();

        std::vector<float> v_weights;
        v_weights.clear();
        v_weights.push_back(-0.1);
        v_weights.push_back(0.206);
        v_weights.push_back(-0.234);

        while(!infile_feat.eof()) {
            getline(infile_feat, feature);
                if(feature.empty())
                    break;
            std::stringstream stringin(feature);
            lines.clear();

            lines.push_back(1);
            while(stringin >> feat_onePoint) {
                lines.push_back(feat_onePoint);
            }
            lines_feat.push_back(lines);
        }

        infile_feat.close();

        std::cout<<"display train data: "<<std::endl;

        DisplayData(lines_feat);

        double l_rate=0.1;

        int n_epoch=5;

        trainW(lines_feat, v_weights, l_rate, n_epoch);

        //predictTestData(lines_feat, v_weights);

        //std::cout<<"the predicted: "<<std::endl;
        //DisplayData(lines_feat);

        return 0;
    }

  • 相关阅读:
    ArcEngine实现对点、线、面的闪烁(转载)
    好久没写博客了.把这几个月的开发过程做一个总结
    利用暴力反编译的程序处理ArcXML数据遇到的问题小结(纯粹研究目的)
    ArcSde 9.2与Oracle 10g是最佳搭档
    当ArcEngine报事件同时存在于AxMapControl,MapControl时的解决方法(转载)
    写在苏州火炬接力的最后一站
    提问,如何才能触发鼠标事件
    地铁线路图高性能查找算法系统,最短路径查询地铁网络拓扑高效率算法原创附带demo
    二分查找
    .net面试题
  • 原文地址:https://www.cnblogs.com/donggongdechen/p/7768691.html
Copyright © 2011-2022 走看看