zoukankan      html  css  js  c++  java
  • 线性回归——最小二乘法拟合

    记录下用最小二乘法拟合线性模型的代码实现:

    1、应用正规方程(Normal Equation)求解最小二乘法举例:

    最简单的y关于x的线性方程(y=eta_0+eta_1x)

    预测值和观察值:

     写成矩阵形式:

    (Xeta=y),其中(X=egin{bmatrix} 1& x_1\ 1& x_2\ ...&...\ 1& x_n\ end{bmatrix}) , (eta=egin{bmatrix} eta_0\ eta_1\ end{bmatrix}) , (y=egin{bmatrix} y_1\ y_2\ ...\ y_n\ end{bmatrix})

    最小二乘法的解(eta)可以通过解正规方程获得: (X^TXeta=X^Ty)

    python代码实现如下:

    import numpy as np
    from numpy.linalg import inv
    
    
    x=[0.015348106072082663,0.021715879765738008,0.0316253067336889,
       0.03212431406271639,0.03828189026158509,0.03965578961513808,
       0.05502733389871053,0.06116957740576664,0.06170785924013203,
       0.07206404835977503]
    y=[22, 0, 22, 11, 9, 31, 20, 31, 2, 20]
    
    
    def plotLinearRegression1():
        X = np.array(x).reshape(len(x),1)
        reg=linear_model.LinearRegression(fit_intercept=True,normalize=False)
        reg.fit(X,y)
        k=reg.coef_#获取斜率w1,w2,w3,...,wn
        b=reg.intercept_#获取截距w0
        #x0=np.arange(0,10,0.2)
        x0 = np.array(x).reshape(len(x),)
        y0=k*x0+b
        plt.scatter(x,y)
        plt.plot(x0, y0)
        print('k',k);
        print('b',b)
      
          
    '''
    用least squares
    '''
    def plotLinearRegression2():
        X = np.vstack([np.ones((1,len(x))),x]).T
        print('X:
    ',X)
        Xt = X.T
        Z = Xt.dot(X)
        invZ = inv(Z)
        print('invZ:
    ',invZ)
        
        W = (invZ.dot(Xt)).dot(y)
        print('W:
    ',W)
        
        #plot
        plt.scatter(x,y)
        x0 = np.array(x).reshape(len(x),)
       # x0=np.arange(0,10,0.2)
        plt.plot(x0, W[1]*x0+W[0])
        
        
    def plotLinearRegression3():
        A = np.vstack([x, np.ones(len(x))]).T
        print('A:
    ',A)
        m, c = np.linalg.lstsq(A, y, rcond=None)[0]
        print('m: ',m,'  c:',c)
       
        x0 = np.array(x).reshape(len(x),)
        plt.plot(x, y, 'o', label='Original data', markersize=10)
        plt.plot(x0, m*x0 + c, 'r', label='Fitted line')
        plt.legend()
        plt.show()  
        
            
    if __name__ == '__main__':
        plotLinearRegression2()

     

  • 相关阅读:
    事后诸葛亮
    OVS常用命令
    阿里云部署杂记
    Alpha冲刺总结
    测试随笔
    Alpha冲刺集合
    项目Alpha冲刺Day12
    项目Alpha冲刺Day11
    项目Alpha冲刺Day10
    MySQL修改密码
  • 原文地址:https://www.cnblogs.com/davidxu/p/14394963.html
Copyright © 2011-2022 走看看