zoukankan      html  css  js  c++  java
  • 一元线性回归-逐像素判断-多组同时运行-矩阵运算

    问题引入

    一元线性回归再简单不过了,实现的方式多种多样。调用scikit-learn linear_model.LinearRegression()、Scipy.polyfit( ) 或 numpy.polyfit( )、stats.linregress( )、optimize.curve_fit( )、numpy.linalg.lstsq、Statsmodels.OLS ( )使用矩阵求逆方法的解析解、高中数学讲的最小二乘法公式,详细请看博客:Python环境下的8种简单线性回归算法

    但如果数据Y(y1,y2,y3,...yn)中出现随机位置与个数的无效值(0值),只有将对应的X(x1,x2,x3,...xn)与Y中的无效值(0值)剔除,得到的Y随X变化趋势才是准确的。对于较少组Y与X的回归,可以采用for循环一组一组的判断计算。但是当有数以亿计组Y与X,用for循环则显得效率低下。遇到这一问题,是由于我在计算多年NDVI序列趋势,每一年的NDVI会出现随机位置与个数的无效值(0值)。

    经过探索,在健忘主义这篇博客启发下,找到了上述问题解决办法。下面分享编程实现思路

    线性回归函数

    通过改写健忘主义这篇博客,利用numpy 数组提供的求和平方等函数,设置操作在axis=0上对列进行求和等运算。通过数组运算提高操作效率。
    函数输入数组X,Y需要提前进行处理,在下节函数调用中说明。

    def lrg_ols(X,Y): 
        '''
        calculate slope and intercept by ols
        refer to https://blog.csdn.net/qq_35515661/article/details/84727471
    
        Parameters
        ----------
        X : 
            TYPE
                n*m ndarray
            DESCRIPTION.
                each column stands for a independent variable series, eg. x1,x2,x3,...,xn.
                which "x1,x2,x3,...,xn" is 1,2, 3 ...n insequence. X can be created by 
                X = np.arange(1,n+1).reshape(n,1)
                X = np.broadcast_to(X,(n,m))            
        Y : 
            TYPE
                n*m ndarray
            DESCRIPTION. 
                each column stands for a dependent variable series, eg. y1 y2 y3...yn.
                note the Y's dimensions must be the same as X's
    
        Returns
        -------
        k : TYPE
                1*m ndarray
            DESCRIPTION.
                corresponding slopes
            
        b : TYPE
                1*m ndarray
            DESCRIPTION.
                corresponding intercept
        '''
        
        X_size = np.count_nonzero(X, axis=0) # X_size should be equal Y_size
        # XYproduct = np.multiply(X,Y) #element-wise produt of X and Y
        # XXproduct = np.multiply(X,X)
        X_mean = np.sum(X,axis=0).reshape(1, -1) / X_size
        Y_mean = np.sum(Y,axis=0).reshape(1, -1) / X_size
        zi = np.sum(X*Y,axis=0).reshape(1, -1) - X_size * X_mean * Y_mean
        mu = np.sum(X*X,axis=0).reshape(1, -1) - X_size * X_mean * X_mean
        k = zi / mu
        b = Y_mean - k * X_mean
        
        # 计算决定系数
        Y_pred = X*k + b
        Y_pred[np.where(X==0)]=0   
        SSR = np.sum(np.square(Y_pred - Y_mean),axis=0).reshape(1, -1) # 回归平方和
        SSE = np.sum(np.square(Y - Y_pred),axis=0).reshape(1, -1) # 残差平方和
        SST = SSR + SSE # 总偏差平方和
        r2 = SSR / SST    
        return k,b,r2,X_size
    

    函数调用

    输入数组X,Y。因变量Y整理成二维数组,每一列是一个因变量时序;自变量整理成与X维度相同的数组,每一列是一个自变量数组。# 然后进行数据筛选,将因变量无效值处赋值为0,并将对应自变量处赋值为0。

    shp = data.shape        
    # reshape后成为每一列是一个时序
    data = data.reshape(shp[0], -1)
    # 构造自变量序列,1,2 ... n,
    X = np.arange(1,shp[0]+1).reshape(shp[0],1)
    X = np.repeat(X,shp[1]*shp[2],axis=1)
    # 数据筛选,将负值赋值为0
    cond = np.where(data<0)
    data[cond]=0
    X[cond]=0
    k,b,r2,xsize = lrg_ols(X,data)
    
  • 相关阅读:
    项目笔记三
    ASP.NET小收集<9>:HTML解析
    TSQL数据维护:更改表所有者
    [转贴]SQL2005:数据类型最大值
    TSQL存储过程:ROW_NUMBER()分页
    JS收集<7>:浏览器event兼容
    统计SQLServer2005表记录数
    ASP.NET小收集<8>:JS创建对象
    JS收集<8>:HTML控件的坐标
    MySql按指定天数进行分组数据统计分析 1
  • 原文地址:https://www.cnblogs.com/yhpan/p/14577963.html
Copyright © 2011-2022 走看看