zoukankan      html  css  js  c++  java
  • [机器学习]多变量线性回归

    多元线性回归

    一、回归模型

    • 多元线性回归的一般表达式

    [f(oldsymbol x_i)=oldsymbol {w^Tx_i}+b = egin{bmatrix} {w_{1}}&{w_{2}}&{cdots}&{w_{d}}\ end{bmatrix} egin{bmatrix} {x_{11}}\ {x_{21}}\ {vdots}\ {x_{d1}}\ end{bmatrix}+b ]

    • 其中将数据D表示为一个矩阵X,将参数b包含在举证中方便后面的运算

    [egin{align} X&= left[ egin{matrix} x_{11} & x_{12} & cdots & x_{1d} & 1 \ x_{21} & x_{22} & cdots & x_{2d} & 1 \ vdots & vdots & ddots & vdots & vdots \ x_{m1} & x_{m2} & cdots & x_{md} & 1 \ end{matrix} ight] = left[ egin{matrix} x^T_1 & 1 \ x^T_2 & 1 \ vdots & vdots \ x^T_m & 1 \ end{matrix} ight] = left[ egin{matrix} 样本1 & 1 \ 样本2 & 1 \ vdots & vdots \ 样本m & 1 \ end{matrix} ight] \ w^*&= left[ egin{matrix} w_1 \ w_2 \ vdots \ w_d \ b \ end{matrix} ight] \ y&=Xw^* end{align} ]

    • 方法类似于单变量线性回归,先写出误差项

      [e_i=y_i-f(x_i) ]

    • 误差累计和

    [egin{align} E &= sum^m_{i=1}e^2_i \ &= sum^m_{i=1}(y_i - left[ egin{matrix} x_i & 1 end{matrix} ight] w^*)^2\ &= ||y-Xw^*||^2_2 \ &= (y-Xw^*)^T(y-Xw^*) \ &= yy^T - (w^*)^TX^Ty -y^TXw^* + (w^*)^TX^TXw^* end{align} ]

    • 误差对w求导

    [egin{align} frac{partial E}{partial w^*} &=-X^Ty-X^Ty+2X^TXw^*=2X^T(Xw^*-y)=0 \ w^*&=(X^TX)^{-1}X^Ty end{align} ]

    • 即可得线性回归模型

    [f(x_i)= left[ egin{matrix} x_i & 1 end{matrix} ight] w^* = left[ egin{matrix} x_i & 1 end{matrix} ight] (X^TX)^{-1}X^Ty =(x^*_i)^T(X^TX)^{-1}X^Ty ]

    二、实例代码

    实例分析

    已知数据集

    特征1 特征2 输出值
    1 0.5 20.52
    2 2.6 27.53
    3 3.5 29.84
    4 6.7 36.92
    5 8.9 38.11
    6 9.2 37.21
    7 11.5 40.96
    8 15.8 46.37
    9 19.6 49.78
    10 20.1 50.22

    求解如下问题

    • 利用最小二乘求解多远线性模型参数,并计算训练的均方误差
    • 在得到模型的基础上,假如输入下面新数据,进行预测,并计算均方误差
    特征1 特征2 输出值
    11 0.85 52.38
    12 16.54 55.66
    13 3.37 58.31

    代码:Matlab

    x1 = [1 2 3 4 5 6 7 8 9 10]';
    x2 = [0.5 2.6 3.5 6.7 8.9 9.2 11.5 15.8 19.6 20.1]';
    y = [20.52 27.53 29.84 36.92 38.11 37.21 40.96 46.37 49.78 50.22]';
    X = [x1 x2 ones(10,1)]
    
    % 训练
    w_ = pinv(X'*X)*X'*y
    y_ = X*w_;
    
    % 预测
    x1_=[11 12 13]';
    x2_=[0.85 16.54 3.37]';
    y_new = [52.38 55.66 58.31]';
    y_new_ = [x1_ x2_ ones(3,1)]*w_
    
    %% 误差
    E = sum((y_-y).^2)/10
    E_t= sum((y_new_-y_new).^2)/3
    
    plot3(x1, x2, y_, 'r-', x1, x2, y, 'bo'); hold on
    legend('predict', 'real')
    

    运行结果

    w_ =
        1.6144
        0.6742
       22.2324
    y_new_ =
       40.5635
       52.7568
       45.4914
    E =
        3.4949
    E_t =
      104.1249
    

    代码:Python

    1.数据构造

    [egin{align} X &= left[ egin{matrix} x_{11} & x_{12} & cdots & x_{1d} & 1 \ x_{21} & x_{22} & cdots & x_{2d} & 1 \ vdots & vdots & ddots & vdots & vdots \ x_{m1} & x_{m2} & cdots & x_{md} & 1 \ end{matrix} ight] = left[ egin{matrix} x^T_1 & 1 \ x^T_2 & 1 \ vdots & vdots \ x^T_m & 1 \ end{matrix} ight] = left[ egin{matrix} 样本1 & 1 \ 样本2 & 1 \ vdots & vdots \ 样本m & 1 \ end{matrix} ight]\ end{align} ]

    import numpy as np
    import math
    from mpl_toolkits.mplot3d import Axes3D
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    x1 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).reshape(10,1)
    x2 = np.array([0.5, 2.6, 3.5, 6.7, 8.9, 9.2, 11.5, 15.8, 19.6, 20.1]).reshape(10,1)
    y  = np.array([20.52, 27.53, 29.84, 36.92, 38.11, 37.21, 40.96, 46.37, 49.78, 50.22]).reshape(10,1)
    
    X = np.hstack((x1, x2, np.ones((10,1),dtype=int)))
    print('X:
    ',X)
    

    结果:

    X:
     [[ 1.   0.5  1. ]
     [ 2.   2.6  1. ]
     [ 3.   3.5  1. ]
     [ 4.   6.7  1. ]
     [ 5.   8.9  1. ]
     [ 6.   9.2  1. ]
     [ 7.  11.5  1. ]
     [ 8.  15.8  1. ]
     [ 9.  19.6  1. ]
     [10.  20.1  1. ]]
    

    2.参数训练

    [w^*= left[ egin{matrix} w_1 \ w_2 \ vdots \ w_m \ b \ end{matrix} ight] =(X^TX)^{-1}X^Ty ]

    w_ = np.dot(np.dot(np.linalg.pinv(np.dot(np.transpose(X),X)), np.transpose(X)) ,y)
    print('w:
    ',w_)
    

    结果:

    w:
     [[ 1.61436503]
     [ 0.67424588]
     [22.23241285]]
    

    3.预测

    [y=Xw^* ]

    y_ = np.dot(X, w_)
    print('y:
    ',y_)
    

    结果:

    y:
     [[24.18390082]
     [27.2141822 ]
     [29.43536853]
     [33.20732038]
     [36.30502636]
     [38.12166515]
     [41.28679571]
     [45.80041804]
     [49.97691742]
     [51.92840539]]
    

    4.计算均方误差

    [E = frac{1}{m}sum^m_{i=1}(y_i-f(x_i))^2 ]

    E = np.sum(pow((y - y_), 2))/len(x1)
    print('E:', E)
    

    5.绘图

    x1_T = x1.ravel()
    x2_T = x2.ravel()
    y_T  = y.ravel()
    y__T = y_.ravel()
    
    fig = plt.figure()
    ax = Axes3D(fig)
    
    ax.scatter(x1_T, x2_T, y_T, label='real', alpha=1)
    ax.plot(x1_T, x2_T, y__T, c='r', label='predict')
    ax.legend()  
    
    plt.show()
    

    结果:

  • 相关阅读:
    《AI for Game Developers》第七章 A*路径寻找算法 (二)(skiplow翻译)
    芯片科普学习笔记
    sprboot 配置logback 日志输出
    springboot+mybatis 配置双数据源(mysql,oracle,sqlserver,db2)
    vue 封装axios请求
    *arg参数
    pytest mac安装了pytest,但是输入pytest却提示命令不存在
    构建Java Web开发环境
    在CentOS上编译安装PostgreSQL
    在Ubuntu 14.04上使用Eclipse开发和调试PosgreSQL9.3.4
  • 原文地址:https://www.cnblogs.com/zou107/p/12520820.html
Copyright © 2011-2022 走看看