  • Least squares

    Linear least squares

    1. Principle:

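    Least squares finds the function that best matches the data by minimizing the sum of squared errors. For a straight-line fit, as shown in the figure, this means choosing the line that minimizes the sum of the squared vertical distances from the points to the line (not the perpendicular distances).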

     

    2. Solution

    Assume the line is:
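The formula image for this step did not survive extraction. Judging from the plotting line `y = x * b + a` in the code below, the assumed model is a straight line with intercept a and slope b:

```latex
y = a + b\,x
```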

    Least squares model:
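A reconstruction consistent with the code below (a sketch, not the author's original image; the symbol S is our own choice): for n data points (x_i, y_i), choose a and b to minimize the sum of squared vertical residuals.

```latex
S(a, b) = \sum_{i=1}^{n} \left( y_i - a - b\,x_i \right)^2 \;\longrightarrow\; \min_{a,\,b}
```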

    Solve:
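Reconstructed step in our notation: the squared error S(a, b) = Σ(y_i − a − b x_i)² is a convex quadratic in a and b, so its minimum is where both partial derivatives vanish:

```latex
\frac{\partial S}{\partial a} = -2 \sum_{i=1}^{n} \left( y_i - a - b\,x_i \right) = 0,
\qquad
\frac{\partial S}{\partial b} = -2 \sum_{i=1}^{n} x_i \left( y_i - a - b\,x_i \right) = 0
```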

    Written as a system of equations:
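Setting the partial derivatives of the squared error with respect to a and b to zero and collecting terms gives the two normal equations (reconstructed). The variable `s` in the code below evaluates n·a + b·Σx_i − Σy_i, which is the first equation, so it should come out as 0 up to floating-point rounding:

```latex
\begin{cases}
n\,a + b \sum_{i=1}^{n} x_i = \sum_{i=1}^{n} y_i \\[4pt]
a \sum_{i=1}^{n} x_i + b \sum_{i=1}^{n} x_i^2 = \sum_{i=1}^{n} x_i y_i
\end{cases}
```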

    Solving gives:
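Solving the two normal equations (for example by Cramer's rule) gives the closed form below. It matches `a` and `b` in the code term by term, and it requires the x values not to be all equal (otherwise the denominator is 0):

```latex
a = \frac{\sum_{i=1}^{n} x_i^2 \sum_{i=1}^{n} y_i - \sum_{i=1}^{n} x_i \sum_{i=1}^{n} x_i y_i}
         {n \sum_{i=1}^{n} x_i^2 - \left( \sum_{i=1}^{n} x_i \right)^2},
\qquad
b = \frac{n \sum_{i=1}^{n} x_i y_i - \sum_{i=1}^{n} x_i \sum_{i=1}^{n} y_i}
         {n \sum_{i=1}^{n} x_i^2 - \left( \sum_{i=1}^{n} x_i \right)^2}
```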

     

    Code implementation:

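    ### Least squares implementation
    import numpy as np
    import matplotlib.pyplot as plt
    
    def linear_least_squares(x, y):
        """Fit y = a + b*x by least squares and return (a, b, s).

        s = n*a + b*sum(x) - sum(y) is the residual of the first normal
        equation; it should be 0 up to floating-point rounding.
        """
        if len(x) != len(y):
            raise ValueError("x and y have different numbers of points")
        if len(x) < 2:
            raise ValueError("at least two points are required")
        X, Y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
        n = len(X)
        sumX, sumY = X.sum(), Y.sum()
        sumXY, sumXX = (X * Y).sum(), (X * X).sum()
        denom = n * sumXX - sumX ** 2
        if denom == 0:
            raise ValueError("all x values are equal; the slope is undefined")
        a = (sumXX * sumY - sumX * sumXY) / denom   # intercept
        b = (n * sumXY - sumX * sumY) / denom       # slope
        s = n * a + b * sumX - sumY                 # check of the first normal equation
        return a, b, s

    x_data = [1.12, 2.62, 2.99, 4, 5, 6.52, 7.31, 8.62, 9.41, 10.58]
    y_data = [0.69*2, 1.99*2, 3.41*2, 4*2, 5.12*2, 6.01*2, 6.98*2, 8.11*2, 9.01*2, 10*2]
    a, b, s = linear_least_squares(x_data, y_data)
    print(s)
    x = np.linspace(0, 11)
    y = a + b * x
    plt.title('s = {:.4e}'.format(s))   # let the format pick the exponent
    plt.plot(x, y)
    plt.scatter(x_data, y_data)
    plt.show()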
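As a quick sanity check (a sketch, not part of the original post), the closed-form formulas can be compared with NumPy's `np.polyfit`, which solves the same degree-1 least squares problem. `fit_line_closed_form` is a hypothetical helper name, and the data are the `x_data`/`y_data` values used above.

```python
import numpy as np

def fit_line_closed_form(x, y):
    """Hypothetical helper: least-squares line y = a + b*x via the sum formulas."""
    X = np.asarray(x, dtype=float)
    Y = np.asarray(y, dtype=float)
    n = len(X)
    sumX, sumY = X.sum(), Y.sum()
    sumXY, sumXX = (X * Y).sum(), (X * X).sum()
    denom = n * sumXX - sumX ** 2              # zero only if all x values are equal
    a = (sumXX * sumY - sumX * sumXY) / denom  # intercept
    b = (n * sumXY - sumX * sumY) / denom      # slope
    return a, b

x_data = [1.12, 2.62, 2.99, 4, 5, 6.52, 7.31, 8.62, 9.41, 10.58]
y_data = [0.69*2, 1.99*2, 3.41*2, 4*2, 5.12*2, 6.01*2, 6.98*2, 8.11*2, 9.01*2, 10*2]

a, b = fit_line_closed_form(x_data, y_data)
# np.polyfit returns coefficients highest power first: [slope, intercept]
slope, intercept = np.polyfit(x_data, y_data, 1)
print("closed form:", a, b)
print("np.polyfit: ", intercept, slope)
```

The two results should agree to within floating-point rounding. For noisy or badly scaled data, `np.polyfit` (built on `np.linalg.lstsq`) is usually the more robust choice.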

    Result:

     

    Extending to the general form

    Substituting the samples gives:
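Reconstruction in our own notation (the original formula images are missing): with m input variables and n samples, write x_{ij} for the j-th input of sample i and θ_0, …, θ_m for the coefficients. Substituting each sample into the model gives n equations:

```latex
y_i = \theta_0 + \theta_1 x_{i1} + \theta_2 x_{i2} + \cdots + \theta_m x_{im},
\qquad i = 1, \dots, n
```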

    Rewritten in matrix form:
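In assumed notation (X has one row per sample, with a leading 1 for the intercept; θ holds the coefficients; y holds the targets), the n equations read as below. With more samples than coefficients the system is overdetermined and usually has no exact solution, which is why the squared error is minimized instead:

```latex
X\theta = y
```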

     

     

     where:
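Assumed definitions (our notation). The first column of ones carries the intercept θ_0; the plotting code further down suggests that the first column of the author's spreadsheet plays this role.

```latex
X = \begin{pmatrix}
1 & x_{11} & \cdots & x_{1m} \\
1 & x_{21} & \cdots & x_{2m} \\
\vdots & \vdots & \ddots & \vdots \\
1 & x_{n1} & \cdots & x_{nm}
\end{pmatrix},
\qquad
\theta = \begin{pmatrix} \theta_0 \\ \theta_1 \\ \vdots \\ \theta_m \end{pmatrix},
\qquad
y = \begin{pmatrix} y_1 \\ y_2 \\ \vdots \\ y_n \end{pmatrix}
```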

     

     

     So we obtain:
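Presumably this is the quantity being minimized, reconstructed in assumed notation (the name J is ours): the sum of squared residuals of Xθ = y, written as a vector norm.

```latex
J(\theta) = \sum_{i=1}^{n} \bigl( y_i - (X\theta)_i \bigr)^2
          = \lVert y - X\theta \rVert^2
          = (y - X\theta)^{\mathsf T}(y - X\theta)
```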

     

    Solving:

    1. Matrix derivation:
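A sketch of the matrix derivation in our notation, where J(θ) = (y − Xθ)ᵀ(y − Xθ) is the squared error; the linked article covers it in full. Two facts are used: yᵀXθ is a scalar and therefore equals θᵀXᵀy, and XᵀX is symmetric.

```latex
\begin{aligned}
J(\theta) &= (y - X\theta)^{\mathsf T}(y - X\theta)
           = y^{\mathsf T}y - 2\,\theta^{\mathsf T}X^{\mathsf T}y + \theta^{\mathsf T}X^{\mathsf T}X\,\theta \\
\frac{\partial J}{\partial \theta} &= -2\,X^{\mathsf T}y + 2\,X^{\mathsf T}X\,\theta = 0
\;\Longrightarrow\; X^{\mathsf T}X\,\theta = X^{\mathsf T}y
\end{aligned}
```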

     

     

    For details of the above, see: https://zhuanlan.zhihu.com/p/24709748

     

     Optimal solution:
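Assuming XᵀX is invertible (that is, the columns of X are linearly independent), the normal equations XᵀXθ = Xᵀy give the closed form that the code below evaluates:

```latex
\hat{\theta} = \left( X^{\mathsf T}X \right)^{-1} X^{\mathsf T} y
```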

     

    Note:

     

     

    2. Algebraic derivation
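A sketch of the same result component by component, in assumed notation (x_{i0} = 1 for the intercept column). Each equation in the last line is row j of XᵀXθ = Xᵀy, because (XᵀX)_{jk} = Σ_i x_{ij} x_{ik} and (Xᵀy)_j = Σ_i x_{ij} y_i.

```latex
\begin{aligned}
J(\theta) &= \sum_{i=1}^{n} \Bigl( y_i - \sum_{k=0}^{m} \theta_k x_{ik} \Bigr)^2,
  \qquad x_{i0} = 1 \\
\frac{\partial J}{\partial \theta_j} &= -2 \sum_{i=1}^{n} x_{ij} \Bigl( y_i - \sum_{k=0}^{m} \theta_k x_{ik} \Bigr) = 0,
  \qquad j = 0, \dots, m \\
&\Longrightarrow\; \sum_{k=0}^{m} \Bigl( \sum_{i=1}^{n} x_{ij} x_{ik} \Bigr) \theta_k = \sum_{i=1}^{n} x_{ij} y_i
\end{aligned}
```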

     

     

     

    Code implementation:

    import numpy as np
    import pandas as pd

    def loadData(filename):
        # Columns labelled 0, 1, 2 form the design matrix X; column 3 is the target y.
        fr = pd.read_excel(filename)
        data = fr[[0, 1, 2]].to_numpy(dtype=float)
        label = fr[3].to_numpy(dtype=float).reshape(-1, 1)
        return data, label

    def linear_least_squares(x, y):
        # Solve the normal equations (X^T X) w = X^T y; np.linalg.solve is
        # more accurate than explicitly inverting X^T X.
        xTx = x.T @ x
        rel = np.linalg.solve(xTx, x.T @ y)
        return rel.ravel().tolist()   # flat list [w0, w1, w2]

    x, y = loadData('C:/Users/shishenhao/Desktop/data.xlsx')
    rel = linear_least_squares(x, y)
    print(rel)
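Forming XᵀX and then inverting it (as `xTx.I` did in the original) can lose accuracy when XᵀX is ill-conditioned. Below is a minimal sketch that compares the normal-equation solution with NumPy's `np.linalg.lstsq`, which solves the least squares problem directly. It uses hypothetical synthetic data, not the author's data.xlsx. `normal_equation_fit` and `true_theta` are illustrative names.

```python
import numpy as np

def normal_equation_fit(X, y):
    """Least-squares coefficients from the normal equations (X^T X) theta = X^T y."""
    return np.linalg.solve(X.T @ X, X.T @ y)

rng = np.random.default_rng(0)
n = 50
x = rng.uniform(0, 7, size=n)
# Hypothetical design matrix with columns 1, x, x^2
X = np.column_stack([np.ones(n), x, x ** 2])
true_theta = np.array([1.0, -2.0, 0.5])
y = X @ true_theta + rng.normal(scale=0.1, size=n)  # add small Gaussian noise

theta_ne = normal_equation_fit(X, y)
theta_ls, *_ = np.linalg.lstsq(X, y, rcond=None)    # SVD-based solver
print("normal equations:", theta_ne)
print("np.linalg.lstsq: ", theta_ls)
```

On well-conditioned data both methods agree. `lstsq` works on X itself rather than on XᵀX, which avoids squaring the condition number; this is the main reason to prefer it.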

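    Note: whether the fitted result is a straight line or a curve depends on the data columns supplied, since the model only has to be linear in the coefficients. In the example below, the plot uses the three coefficients as y = rel[0] + rel[1]·x + rel[2]·x², which suggests that the three input columns hold 1, x and x²: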

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    def loadData(filename):
        # Columns labelled 0, 1, 2 form the design matrix X; column 3 is the target y.
        fr = pd.read_excel(filename)
        data = fr[[0, 1, 2]].to_numpy(dtype=float)
        label = fr[3].to_numpy(dtype=float).reshape(-1, 1)
        return data, label

    def linear_least_squares(x, y):
        # Solve the normal equations (X^T X) w = X^T y.
        rel = np.linalg.solve(x.T @ x, x.T @ y)
        return rel.ravel().tolist()   # flat list [w0, w1, w2]

    x, y = loadData('C:/Users/shishenhao/Desktop/data.xlsx')
    rel = linear_least_squares(x, y)
    # Linear in the coefficients but quadratic in x; column 1 of X holds x.
    x1 = np.linspace(0, 7)
    y1 = rel[0] + rel[1] * x1 + rel[2] * x1 ** 2
    plt.plot(x1, y1)
    plt.scatter(x[:, 1], y[:, 0])
    plt.show()
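The following minimal sketch makes the line-versus-curve point concrete. It uses hypothetical noise-free samples of y = 1 − 2x + 0.5x². When the design matrix has columns 1, x, x² (built here with `np.vander`), the same normal-equation solver fits the curve exactly. When it has only columns 1, x, the solver can only produce a straight line.

```python
import numpy as np

# Hypothetical noise-free samples of y = 1 - 2x + 0.5x^2
x = np.linspace(0, 7, 15)
y = 1 - 2 * x + 0.5 * x ** 2

# Columns 1, x, x^2: the model is still linear in the coefficients
X = np.vander(x, 3, increasing=True)
coef = np.linalg.solve(X.T @ X, X.T @ y)
print(coef)  # approximately [1.0, -2.0, 0.5]

# Columns 1, x only: the best the same solver can do is a straight line
X_line = X[:, :2]
line_coef = np.linalg.solve(X_line.T @ X_line, X_line.T @ y)

sse_curve = np.sum((X @ coef - y) ** 2)
sse_line = np.sum((X_line @ line_coef - y) ** 2)
print(sse_curve, sse_line)  # the quadratic design fits (near) exactly, the line does not
```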

    Fitting and plotting lines and curves (curve_fit() explained):

    https://blog.csdn.net/guduruyu/article/details/70313176

     

  • Original post: https://www.cnblogs.com/shish/p/12274646.html