zoukankan      html  css  js  c++  java
  • SimpleLinearRegression

    import numpy as np
    from .metrics import r2_score


    class SimpleLinearRegression:

    def __init__(self):
    """初始化Simple Linear Regression模型"""
    self.a_ = None
    self.b_ = None

    def fit(self, x_train, y_train):
    """根据训练数据集x_train, y_train训练Simple Linear Regression模型"""
    assert x_train.ndim == 1,
    "Simple Linear Regressor can only solve single feature training data."
    assert len(x_train) == len(y_train),
    "the size of x_train must be equal to the size of y_train"

    x_mean = np.mean(x_train)
    y_mean = np.mean(y_train)

    self.a_ = (x_train - x_mean).dot(y_train - y_mean) / (x_train - x_mean).dot(x_train - x_mean)
    self.b_ = y_mean - self.a_ * x_mean

    return self

    def predict(self, x_predict):
    """给定待预测数据集x_predict,返回表示x_predict的结果向量"""
    assert x_predict.ndim == 1,
    "Simple Linear Regressor can only solve single feature training data."
    assert self.a_ is not None and self.b_ is not None,
    "must fit before predict!"

    return np.array([self._predict(x) for x in x_predict])

    def _predict(self, x_single):
    """给定单个待预测数据x,返回x的预测结果值"""
    return self.a_ * x_single + self.b_

    def score(self, x_test, y_test):
    """根据测试数据集 x_test 和 y_test 确定当前模型的准确度"""

    y_predict = self.predict(x_test)
    return r2_score(y_test, y_predict)

    def __repr__(self):
    return "SimpleLinearRegression()"
  • 相关阅读:
    【jQuery】清空表单内容
    【jQuery】remove()和empty()的使用
    【ajax 提交表单】多种方式的注意事项 ,serialize()的使用
    【Gson】互相转化
    yum安装nginx详解
    linux find命令
    nginx实战
    java判断是否为汉字
    分布式存储 CentOS6.5虚拟机环境搭建FastDFS-5.0.5集群(转载)
    Java应用程序实现屏幕的"拍照"
  • 原文地址:https://www.cnblogs.com/heguoxiu/p/10135588.html
Copyright © 2011-2022 走看看