zoukankan      html  css  js  c++  java
  • 线性回归(Linear Regression)的理解及原理

    线性回归的基本含义

      在统计学中,线性回归(Linear Regression)是利用称为线性回归方程的最小平方函数对一个或多个自变量因变量之间关系进行建模的一种回归分析。这种函数是一个或多个称为回归系数的模型参数的线性组合。只有一个自变量的情况称为简单回归,大于一个自变量情况的叫做多元回归。(这反过来又应当由多个相关的因变量预测的多元线性回归区别,而不是一个单一的标量变量。)
      在线性回归中,数据使用线性预测函数来建模,并且未知的模型参数也是通过数据来估计。这些模型被叫做线性模型。最常用的线性回归建模是给定X值的y的条件均值是X的仿射函数。不太一般的情况,线性回归模型可以是一个中位数或一些其他的给定X的条件下y的条件分布的分位数作为X的线性函数表示。像所有形式的回归分析一样,线性回归也把焦点放在给定X值的y的条件概率分布,而不是X和y的联合概率分布(多元分析领域)。
      例如,房子的面积 x与房子的价格 y具有一定的线性关系,也就是说,我们可以画出能够大致表示 x与 y关系的一条直线,如下图:

      在该直线中,房子的面积 x为自变量,房子的价格 y为因变量。而“线性回归”的目的就是,利用自变量 x与因变量 y,来学习出这么一条能够描述两者之间关系的线对于一元线性回归来说就是学习出一条直线,而对于多元线性回归来说则是学习出一个超平面

    一元线性回归模型

    在概念上理解了线性回归是什么之后,我们就需要将线性回归的问题进行抽象化,转换成我们能够求解的数学问题。
    在上面的例子中,我们可以看出自变量 x与因变量 y致成线性关系,因此我们可以对因变量做如下假设(hypothesis):

    或者记作:

          其中 i=1,2,...,mi=1,2,...,m
       在这里使用是由于通过观察,我们可以发现直线并没有完全拟合数据,而是存在一定的误差。该假设即为一元线性函数的模型函数,其中含有两个参数。其中可视为斜率, 直线在y 轴上的截距。接下来的任务就是如何求得这两个未知参数。

    优化算法

    在这里我们将使用两种方法进行参数的求解:(1)最小二乘法(2)梯度下降法

    1、最小二乘法

      最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。最小二乘法还可用于曲线拟合。

     

     简单推导过程:

     2、梯度下降法

      如下图所示,将梯度下降的原理形象地描述为下山,直到获得一个局部或者全局最小值。在每次迭代中,根据给定的学习速率和梯度的斜率,能够确定每次移动的步幅,按照步幅沿着梯度方向前进一步。

    简单代码实现如下:

     1 class LinearModel(object):
     2     def __init__(self):
     3 #       初始值
     4         self.w = np.random.randn(1)[0]
     5         self.b = np.random.randn(1)[0]
     6         
     7     def model(self,x):#线性模型
     8         return self.w * x + self.b
     9     
    10     def cost(self,x,y):
    11         c = (self.model(x) - y)**2
    12 #       偏导数(包含导数) = 梯度
    13         dw = 2*(self.model(x) - y)*x
    14         db = 2*(self.model(x) - y)*1
    15         return c,dw,db
    16 #       梯度下降更新数据
    17     def gradient_descent(self,dw,db,step):
    18         self.w -= dw*step
    19         self.b -= db*step
    20         
    21     def fit(self,X,y):
    22         w_last = self.w + 0.1
    23         b_last = self.b + 0.1
    24         length = len(X)
    25         count = 1
    26         while True:
    27             if (abs(self.w - w_last) < 1e-4) and (abs(self.b - b_last) < 1e-4):
    28                 print('*********************')
    29                 break
    30             cost_ = 0
    31             derivative_w = 0
    32             derivative_b = 0
    33             for xi,yi in zip(X,y):
    34                 c_,dw_,db_ = self.cost(xi[0],yi)# 求解的是每一个样本的损失、w偏导、b偏导
    35                 cost_ += c_/length
    36                 derivative_w += dw_/length
    37                 derivative_b += db_/length
    38 #           cost_、derivative_w、derivative 整体的损失,偏导
    39             w_last = self.w
    40             b_last = self.b
    41 
    42             self.gradient_descent(derivative_w,derivative_b,0.01)
    43 
    44             print('--------------------------------梯度下降更新次数是:%d。损失是:%0.4f'%(count,cost_))
    45             count +=1
    46             
    47         print('++++++++++++++++++++++++++++++++++梯度下降计算的斜率是:%0.4f。计算的截距是:%0.4f'%(self.w,self.b))
    48 
    49 linear_model = LinearModel()
    50 linear_model.fit(X,y)

     

    线性回归算法的优缺点

    优点:

        (1)思想简单,实现容易。建模迅速,对于小数据量、简单的关系很有效;
        (2)是许多强大的非线性模型的基础。
        (3)线性回归模型十分容易理解,结果具有很好的可解释性,有利于决策分析。
        (4)蕴含机器学习中的很多重要思想。
        (5)能解决回归问题。

    缺点:

        (1)对于非线性数据或者数据特征间具有相关性多项式回归难以建模.
        (2)难以很好地表达高度复杂的数据。
  • 相关阅读:
    webdriver学习
    [Sqlite]-->Java使用jdbc连接Sqlite数据库进行各种数据操作的详细过程(转)
    java 二维码
    java 解析json超大文件(转)
    嵌套三目运算符
    实体的字段以is开头的教训
    easyui中formatter的使用
    springmvc中的controller是单例的
    hibernate 中baseservice中添加事物
    easyui中添加富文本编辑器
  • 原文地址:https://www.cnblogs.com/wuzc/p/12792942.html
Copyright © 2011-2022 走看看