zoukankan      html  css  js  c++  java
  • 简单线性回归预测实现

    import numpy as np
    from .metrics import r2_score
    
    
    class SimpleLinearRegression:
    
        def __init__(self):
            """初始化Simple Linear Regression模型"""
            self.a_ = None
            self.b_ = None
    
        def fit(self, x_train, y_train):
            """根据训练数据集x_train, y_train训练Simple Linear Regression模型"""
            assert x_train.ndim == 1, 
                "Simple Linear Regressor can only solve single feature training data."
            assert len(x_train) == len(y_train), 
                "the size of x_train must be equal to the size of y_train"
    
            x_mean = np.mean(x_train)
            y_mean = np.mean(y_train)
    
            self.a_ = (x_train - x_mean).dot(y_train - y_mean) / (x_train - x_mean).dot(x_train - x_mean)#向量化 点乘运算
            self.b_ = y_mean - self.a_ * x_mean
    
            return self
    
        def predict(self, x_predict):
            """给定待预测数据集x_predict,返回表示x_predict的结果向量"""
            assert x_predict.ndim == 1, 
                "Simple Linear Regressor can only solve single feature training data."
            assert self.a_ is not None and self.b_ is not None, 
                "must fit before predict!"
    
            return np.array([self._predict(x) for x in x_predict])
    
        def _predict(self, x_single):
            """给定单个待预测数据x,返回x的预测结果值"""
            return self.a_ * x_single + self.b_
    
        def score(self, x_test, y_test):
            """根据测试数据集 x_test 和 y_test 确定当前模型的准确度"""
    
            y_predict = self.predict(x_test)
            return r2_score(y_test, y_predict)
    
        def __repr__(self):
            return "SimpleLinearRegression()"

    衡量线性回归模型误差的三种方式

    def mean_squared_error(y_true, y_predict):
        """计算y_true和y_predict之间的MSE"""
        assert len(y_true) == len(y_predict), 
            "the size of y_true must be equal to the size of y_predict"
    
        return np.sum((y_true - y_predict)**2) / len(y_true)
    
    
    def root_mean_squared_error(y_true, y_predict):
        """计算y_true和y_predict之间的RMSE"""
    
        return sqrt(mean_squared_error(y_true, y_predict))
    
    
    def mean_absolute_error(y_true, y_predict):
        """计算y_true和y_predict之间的RMSE"""
        assert len(y_true) == len(y_predict), 
            "the size of y_true must be equal to the size of y_predict"
    
        return np.sum(np.absolute(y_true - y_predict)) / len(y_true)

    计算最终模型预测准确度R方

    def r2_score(y_true, y_predict):
        """计算y_true和y_predict之间的R Square"""
    
        return 1 - mean_squared_error(y_true, y_predict)/np.var(y_true)
  • 相关阅读:
    由于挂载的nfs存储目录掉下线,导致创建VM时,无法创建
    使用RVM更新Ruby 版本
    安装logstash+kibana+elasticsearch+redis搭建集中式日志分析平台
    Topic modeling【经典模型】
    [第1集] 课程目标,数据类型,运算,变量
    Juint test Case 的2种使用方式
    getJSON方式请求服务器
    Web项目改名的带来的404not found问题
    JavaWeb EL表达式, JSTL标签及过滤器综合学习
    HashMap的几种遍历方式(转载)
  • 原文地址:https://www.cnblogs.com/Erick-L/p/9031064.html
Copyright © 2011-2022 走看看