zoukankan      html  css  js  c++  java
  • C++ LinearRegression代码实现

    这里基本完全参考网络资源完成,有疑问欢迎留言!

    LinearRegression.h

    #pragma once
    /* BUG FIX: the original checked ML_LINEAEEEGRESSION_H (typo) but defined
       ML_LINEARREGRESSION_H, so the classic include guard never worked; only
       #pragma once saved it. The two macro names now match. */
    #ifndef ML_LINEARREGRESSION_H
    #define ML_LINEARREGRESSION_H
    /* Simple univariate linear regression, h(x) = theta[0] + theta[1] * x,
       trained with batch gradient descent. */
    class LinearRegression {
    public:
        /* Feature values (not owned; caller keeps the array alive). */
        double *x;
        /* Target values (not owned). */
        double *y;
        /* Number of samples in x / y. */
        int m;
        /* Fitted coefficients: theta[0] = intercept, theta[1] = slope.
           Valid only after train() has run; allocated with new[2] there. */
        double *theta;
        /* Store pointers to the caller's training data; no copy is made. */
        LinearRegression(double x[], double y[], int m);
        /* Fit theta by gradient descent.
           alpha: learning rate; iterations: number of update steps. */
        void train(double alpha, int iterations);
        /* Predict y for a single feature value using the fitted theta. */
        double predict(double x);
    private:
        /* Mean-squared-error cost J(theta) over m samples. */
        static double compute_cost(double x[], double y[], double theta[], int m);
        /* Hypothesis for one sample: theta[0] + theta[1] * x. */
        static double h(double x, double theta[]);
        /* Predictions for all m samples; returns a new[] array the caller owns. */
        static double *calculate_predictions(double x[], double theta[], int m);
        /* Batch gradient descent; writes per-iteration cost into j and
           returns a new[2] theta array the caller owns. */
        static double *gradient_descent(double x[], double y[], double alpha, int iter, double *j, int m);

    };
    #endif // !ML_LINEARREGRESSION_H

    LinearRegression.cpp

    #include "iostream"
    #include "linearRegression.h"
    #include "Utils.h"
    using namespace std;
    
    /* Constructor: keep references to the caller's sample arrays (no copy).
       The caller must keep x and y alive for the lifetime of this object. */
    LinearRegression::LinearRegression(double x[], double y[], int m)
        : x(x), y(y), m(m)
    {
    }
    
    /*
    alpha:learn rate
    iterations:iterators
    */
    void LinearRegression::train(double alpha, int iterations)
    {
        double *J = new double[iterations];
        this->theta = gradient_descent(x, y, alpha, iterations, J, m);
        cout << "J=";
        for (int i = 0; i < iterations; ++i)
        {
            cout << J[i] << " " << endl;;
        }
        cout << "
    " << "Theta: " << theta[0] << " " << theta[1] << endl;
    }
    /* Predict y' for a single feature value; prints the value and returns it.
       Requires train() to have been called first (theta must be set). */
    double LinearRegression::predict(double x)
    {
        /* Compute once instead of calling h() twice as the original did. */
        const double prediction = h(x, theta);
        cout << "y':" << prediction << endl;
        return prediction;
    }
    
    /* Mean-squared-error cost: J = (1 / 2m) * sum_i (h(x_i) - y_i)^2. */
    double LinearRegression::compute_cost(double x[], double y[], double theta[], int m)
    {
        double *predictions = calculate_predictions(x, theta, m);
        double *diff = Utils::array_diff(predictions, y, m);
        double *sq_errors = Utils::array_pow(diff, m, 2);
        const double cost = (1.0 / (2 * m)) * Utils::array_sum(sq_errors, m);
        /* calculate_predictions allocates with new[], so release it after the
           last use (was leaked in the original).
           NOTE(review): Utils::array_diff / array_pow presumably allocate as
           well, in which case diff and sq_errors still leak -- the Utils
           ownership contract is not visible from this file. TODO confirm. */
        delete[] predictions;
        return cost;
    }
    /* Hypothesis for a single sample: intercept plus slope times x. */
    double LinearRegression::h(double x, double theta[])
    {
        double value = theta[0];
        value += theta[1] * x;
        return value;
    }
    /* Evaluate the hypothesis on every sample.
       Returns a new[] array of m predictions; the caller owns and frees it. */
    double *LinearRegression::calculate_predictions(double x[], double theta[], int m)
    {
        double *out = new double[m];
        int i = 0;
        while (i < m)
        {
            out[i] = h(x[i], theta);
            ++i;
        }
        return out;
    }
    /* Batch gradient descent for the two-parameter linear model.
       Starts from theta = {1, 1}; records the cost after each step in J[i].
       Returns a new[2] array {intercept, slope} that the caller owns. */
    double *LinearRegression::gradient_descent(double x[], double y[], double alpha, int iter, double *J, int m)
    {
        double *theta = new double[2];
        theta[0] = 1;
        theta[1] = 1;
        for (int i = 0; i < iter; i++)
        {
            double *predictions = calculate_predictions(x, theta, m);
            double *diff = Utils::array_diff(predictions, y, m);
            double *error_x1 = diff;
            double *error_x2 = Utils::array_multiplication(diff, x, m);
            /* Simultaneous update of both coefficients. A cost threshold or a
               gradient-change threshold could be used for early stopping
               (note carried over from the original). */
            theta[0] = theta[0] - alpha*(1.0 / m) * Utils::array_sum(error_x1, m);
            theta[1] = theta[1] - alpha*(1.0 / m)*Utils::array_sum(error_x2, m);
            J[i] = compute_cost(x, y, theta, m);
            /* calculate_predictions allocates with new[]; free it each pass
               (the original leaked it on every iteration).
               NOTE(review): diff / error_x2 from Utils presumably allocate
               too and still leak -- ownership not visible here. TODO confirm. */
            delete[] predictions;
        }
        return theta;
    }

    Test.cpp

    #include "iostream"
    #include "linearRegression.h"
    
    using namespace std;
    
    int main()
    {
        /* BUG FIX: the original declared double x[10] / y[10] with only five
           initializers and passed m = 10; aggregate initialization zero-fills
           the remaining elements, so five silent (0,0) samples were included
           in the training set and skewed the fit. Use exactly five samples. */
        double x[5] = { 1, 2, 3, 4, 5 };
        double y[5] = { 4, 6, 8, 11, 12 };

        LinearRegression test(x, y, 5);
        test.train(0.1, 100);
        test.predict(7);
        system("pause");  // Windows-only console pause kept from the original
        return 0;
    }
  • 相关阅读:
    让svn自动更新代码注释中的版本号
    前端开发利器F5
    当inlineblock和textindent遇到IE6,IE7
    DSL与函数式编程
    [译]当Node.js遇上WebMatrix 2
    《大道至简》的读后感
    深度学习之卷积神经网络之一
    ORACLE TRUNC()函数
    oracle rownum
    ORACLE 忽略已有重复值 创建唯一约束
  • 原文地址:https://www.cnblogs.com/zhibei/p/11787980.html
Copyright © 2011-2022 走看看