zoukankan      html  css  js  c++  java
  • 简单线性回归预测实现

    import numpy as np
    from .metrics import r2_score
    
    
    class SimpleLinearRegression:
    
        def __init__(self):
            """初始化Simple Linear Regression模型"""
            self.a_ = None
            self.b_ = None
    
        def fit(self, x_train, y_train):
            """根据训练数据集x_train, y_train训练Simple Linear Regression模型"""
            assert x_train.ndim == 1, 
                "Simple Linear Regressor can only solve single feature training data."
            assert len(x_train) == len(y_train), 
                "the size of x_train must be equal to the size of y_train"
    
            x_mean = np.mean(x_train)
            y_mean = np.mean(y_train)
    
            self.a_ = (x_train - x_mean).dot(y_train - y_mean) / (x_train - x_mean).dot(x_train - x_mean)#向量化 点乘运算
            self.b_ = y_mean - self.a_ * x_mean
    
            return self
    
        def predict(self, x_predict):
            """给定待预测数据集x_predict,返回表示x_predict的结果向量"""
            assert x_predict.ndim == 1, 
                "Simple Linear Regressor can only solve single feature training data."
            assert self.a_ is not None and self.b_ is not None, 
                "must fit before predict!"
    
            return np.array([self._predict(x) for x in x_predict])
    
        def _predict(self, x_single):
            """给定单个待预测数据x,返回x的预测结果值"""
            return self.a_ * x_single + self.b_
    
        def score(self, x_test, y_test):
            """根据测试数据集 x_test 和 y_test 确定当前模型的准确度"""
    
            y_predict = self.predict(x_test)
            return r2_score(y_test, y_predict)
    
        def __repr__(self):
            return "SimpleLinearRegression()"

    衡量线性回归模型误差的三种方式

    def mean_squared_error(y_true, y_predict):
        """计算y_true和y_predict之间的MSE"""
        assert len(y_true) == len(y_predict), 
            "the size of y_true must be equal to the size of y_predict"
    
        return np.sum((y_true - y_predict)**2) / len(y_true)
    
    
    def root_mean_squared_error(y_true, y_predict):
        """计算y_true和y_predict之间的RMSE"""
    
        return sqrt(mean_squared_error(y_true, y_predict))
    
    
    def mean_absolute_error(y_true, y_predict):
        """计算y_true和y_predict之间的RMSE"""
        assert len(y_true) == len(y_predict), 
            "the size of y_true must be equal to the size of y_predict"
    
        return np.sum(np.absolute(y_true - y_predict)) / len(y_true)

    计算最终模型预测准确度R方

    def r2_score(y_true, y_predict):
        """计算y_true和y_predict之间的R Square"""
    
        return 1 - mean_squared_error(y_true, y_predict)/np.var(y_true)
  • 相关阅读:
    detectron源码阅读,avgpool是什么意思、特征意思
    detectron2 特征金字塔代码讲解,detectron2train的过程
    detectron2 一个整体的demo说明自定义数据集以及训练过程
    fcos学习问题
    直接debug在源码上修改的方式
    为什么如果待分类类别是10个,类别范围可以设置成(9,11)
    delphi 各版本下载
    oracle11g 修改字符集 修改为ZHS16GBK
    delphi 判断调试状态
    fastreport5 有变量时 不加引号字符串出错 提示没有声明的变量
  • 原文地址:https://www.cnblogs.com/Erick-L/p/9031064.html
Copyright © 2011-2022 走看看