zoukankan      html  css  js  c++  java
  • 线性回归

      1   
      2 # -*- coding: UTF-8 -*-
      3 """
      4 此脚本用于展示使用sklearn搭建线性回归模型
      5 """
      6 
      7 
      8 import os
      9 import sys
     10 
     11 import numpy as np
     12 import matplotlib.pyplot as plt
     13 import pandas as pd
     14 from sklearn import linear_model
     15 
     16 
     17 def evaluateModel(model, testData, features, labels):
     18     """
     19     计算线性模型的均方差和决定系数
     20     参数
     21     ----
     22     model : LinearRegression, 训练完成的线性模型
     23     testData : DataFrame,测试数据
     24     features : list[str],特征名列表
     25     labels : list[str],标签名列表
     26     返回
     27     ----
     28     error : np.float64,均方差
     29     score : np.float64,决定系数
     30     """
     31     # 均方差(The mean squared error),均方差越小越好
     32     error = np.mean(
     33         (model.predict(testData[features]) - testData[labels]) ** 2)
     34     # 决定系数(Coefficient of determination),决定系数越接近1越好
     35     score = model.score(testData[features], testData[labels])
     36     return error, score
     37 
     38 
     39 def visualizeModel(model, data, features, labels, error, score):
     40     """
     41     模型可视化
     42     """
     43     # 为在Matplotlib中显示中文,设置特殊字体
     44     plt.rcParams['font.sans-serif']=['SimHei']
     45     # 创建一个图形框
     46     fig = plt.figure(figsize=(6, 6), dpi=80)
     47     # 在图形框里只画一幅图
     48     ax = fig.add_subplot(111)
     49     # 在Matplotlib中显示中文,需要使用unicode
     50     # 在Python3中,str不需要decode
     51     if sys.version_info[0] == 3:
     52         ax.set_title(u'%s' % "线性回归示例")
     53     else:
     54         ax.set_title(u'%s' % "线性回归示例".decode("utf-8"))
     55     ax.set_xlabel('$x$')
     56     ax.set_ylabel('$y$')
     57     # 画点图,用蓝色圆点表示原始数据
     58     # 在Python3中,str不需要decode
     59     if sys.version_info[0] == 3:
     60         ax.scatter(data[features], data[labels], color='b',
     61             label=u'%s: $y = x + epsilon$' % "真实值")
     62     else:
     63         ax.scatter(data[features], data[labels], color='b',
     64             label=u'%s: $y = x + epsilon$' % "真实值".decode("utf-8"))
     65     # 根据截距的正负,打印不同的标签
     66     if model.intercept_ > 0:
     67         # 画线图,用红色线条表示模型结果
     68         # 在Python3中,str不需要decode
     69         if sys.version_info[0] == 3:
     70             ax.plot(data[features], model.predict(data[features]), color='r',
     71                 label=u'%s: $y = %.3fx$ + %.3f'
     72                 % ("预测值", model.coef_, model.intercept_))
     73         else:
     74             ax.plot(data[features], model.predict(data[features]), color='r',
     75                 label=u'%s: $y = %.3fx$ + %.3f'
     76                 % ("预测值".decode("utf-8"), model.coef_, model.intercept_))
     77     else:
     78         # 在Python3中,str不需要decode
     79         if sys.version_info[0] == 3:
     80             ax.plot(data[features], model.predict(data[features]), color='r',
     81                 label=u'%s: $y = %.3fx$ - %.3f'
     82                 % ("预测值", model.coef_, abs(model.intercept_)))
     83         else:
     84             ax.plot(data[features], model.predict(data[features]), color='r',
     85                 label=u'%s: $y = %.3fx$ - %.3f'
     86                 % ("预测值".decode("utf-8"), model.coef_, abs(model.intercept_)))
     87     legend = plt.legend(shadow=True)
     88     legend.get_frame().set_facecolor('#6F93AE')
     89     # 显示均方差和决定系数
     90     # 在Python3中,str不需要decode
     91     if sys.version_info[0] == 3:
     92         ax.text(0.99, 0.01, 
     93             u'%s%.3f
    %s%.3f'
     94             % ("均方差:", error, "决定系数:", score),
     95             style='italic', verticalalignment='bottom', horizontalalignment='right',
     96             transform=ax.transAxes, color='m', fontsize=13)
     97     else:
     98          ax.text(0.99, 0.01, 
     99             u'%s%.3f
    %s%.3f'
    100             % ("均方差:".decode("utf-8"), error, "决定系数:".decode("utf-8"), score),
    101             style='italic', verticalalignment='bottom', horizontalalignment='right',
    102             transform=ax.transAxes, color='m', fontsize=13)
    103     # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
    104     # 在Python shell里面,可以设置参数"block=False",使阻断失效。
    105     plt.show()
    106 
    107 
    108 def trainModel(trainData, features, labels):
    109     """
    110     利用训练数据,估计模型参数
    111     参数
    112     ----
    113     trainData : DataFrame,训练数据集,包含特征和标签
    114     features : 特征名列表
    115     labels : 标签名列表
    116     返回
    117     ----
    118     model : LinearRegression, 训练好的线性模型
    119     """
    120     # 创建一个线性回归模型
    121     model = linear_model.LinearRegression()
    122     # 训练模型,估计模型参数
    123     model.fit(trainData[features], trainData[labels])
    124     return model
    125 
    126 
    127 def linearModel(data):
    128     """
    129     线性回归模型建模步骤展示
    130     参数
    131     ----
    132     data : DataFrame,建模数据
    133     """
    134     features = ["x"]
    135     labels = ["y"]
    136     # 划分训练集和测试集
    137     trainData = data[:15]
    138     testData = data[15:]
    139     # 产生并训练模型
    140     model = trainModel(trainData, features, labels)
    141     # 评价模型效果
    142     error, score = evaluateModel(model, testData, features, labels)
    143     # 图形化模型结果
    144     visualizeModel(model, data, features, labels, error, score)
    145 
    146 
    147 def readData(path):
    148     """
    149     使用pandas读取数据
    150     """
    151     data = pd.read_csv(path)
    152     return data
    153 
    154 
    155 if __name__ == "__main__":    #主模块的名字是__main__,import的模块名字是自己
    156     homePath = os.path.dirname(os.path.abspath(__file__))  #os.path.dirname 是去掉文件名的路径 ,abspath获取当前文件路径
    157     # Windows下的存储路径与Linux并不相同
    158     if os.name == "nt":   #判断当前使用的平台,nt为windows
    159         dataPath = "%s\data\simple_example.csv" % homePath
    160     else:
    161         dataPath = "%s/data/simple_example.csv" % homePath
    162     data = readData(dataPath)
    163     linearModel(data)
    164 © 2019 GitHub, Inc.
    165 Terms
    166 Privacy
    167 Security
    168 Status
    169 Help
    170 Contact GitHub
    171 Pricing
    172 API
    173 Training
    174 Blog
    175 About
  • 相关阅读:
    切割栅格数据 切割raster
    缓存讲解
    Arcengine动态发布WMS
    dos命令
    在遥感影像中,立体相对观测的原理是什么?
    Top 10 steps to optimize data access in SQL Server: Part V (Optimize database files and apply partitioning)
    http://blog.csdn.net/itanders
    How to receive Deadlock information automatically via email
    减负
    Provisioning a New SQL Server Instance Series
  • 原文地址:https://www.cnblogs.com/bbgoal/p/10793527.html
Copyright © 2011-2022 走看看