zoukankan      html  css  js  c++  java
  • 郑捷《机器学习算法原理与编程实践》学习笔记(第七章 预测技术与哲学)7.1 线性系统的预测

     7.1.1 回归与现代预测

     7.1.2 最小二乘法

     7.1.3 代码实现

    (1)导入数据

    def loadDataSet(self,filename):     #加载数据集
        X = [];Y = []
        fr = open(filename)
        for line in fr.readlines():
            curLine = line.strip().split('	')
            X.append(float(curLine[0]))
            Y.append(float(curLine[-1]))
        return X,Y
    (2)绘制图形函数
    # (2)绘制图形函数
    def plotscatter(Xmat,Ymat,a,b,plt):
        fig = plt.figure()
        ax  = fig.add_subplot(111) #绘制图形位置
        ax.scatter(Xmat,Ymat,c='blue',marker='o')#绘制散点图
        Xmat.sort()                #对Xmat元素进行排序
        yhat = [a.float(xi)+b for xi in Xmat] #计算预测值
        plt.plot(Xmat,yhat,'r')
        plt.show()
        return  yhat

    (3)主函数

    Xmat,Ymat = loadDataSet("regdataset.txt") #导入数据文件
    meanX     = mean(Xmat)                      #原始数据的均值
    meanY     = mean(Ymat)                      #原始数据的均值
    dX        = Xmat-meanX                      #各元素与均值的差
    dY        = Ymat-meanY                      #各元素与均值的差
    #手工计算
    # sumXY = 0;Sqx = 0
    # for i in xrange(len(dx)):
    #     sumXY += double(dx[i])*double(dy[i])
    #     Sqx   = double(dX[i])**2
    
    sumXY = vdot(dX,dY) #返回两个向量的点乘multiply
    Sqx   = sum(power(dX,2))#向量的平方:(X-meanX)^2
    
    #计算斜率和截距
    a = sumXY/Sqx
    b = meanY-a*meanX
    print a,b
    #绘制图形
    plotscatter(Xmat,Ymat,a,b,plt)
     7.1.4 正规方程组法
    7.1.5 正规方程组的代码实现
     
    #数据矩阵,分类标签
    xArr,yArr = loadDataSet("regdataset.txt") #导入数据文件
    
    m = len(xArr) #生成X坐标列
    Xmat = mat(ones((m,2)))
    for i in xrange(m):
        Xmat[i,1] = xArr[i]
    Ymat = mat(yArr).T   #转化为Y列
    xTx = Xmat.T*Xmat
    
    ws = [] #直线的斜率和截距
    if linalg.det(xTx) != 0.0:    #行列式不为0
        ws = linalg.inv(Xmat.T*Xmat)*(Xmat.T*Ymat)#矩阵的正规方程组的公式:inv(X.T*X)*X.T*Y
    else:
        print  u"矩阵为奇异阵,无逆矩阵"
        sys.exit(0)#退出程序
    print  "ws:",ws
     

    资料来源:郑捷《机器学习算法原理与编程实践》 仅供学习研究

  • 相关阅读:
    生成指定范围的随机数
    sql
    map的使用
    基础03 JVM到底在哪里?
    Elasticsearch6.1.0 TransportClient聚合查询索引中所有数据
    Elasticsearch6.1.0 TransportClient滚动查询索引中所有数据写入文件中
    elasticsearch-java api中get() 和execute().actionGet()方法
    Elasticsearch6(Transport Client)常用操作
    Reflections反射获取注解下类
    Ambari2.6.0 安装HDP2.6.3: Python script has been killed due to timeout after waiting 300 secs
  • 原文地址:https://www.cnblogs.com/wuchuanying/p/6409176.html
Copyright © 2011-2022 走看看