zoukankan      html  css  js  c++  java
  • SimpleLinearRegression

     1 # coding: utf-8
     2 
     3 # In[3]:
     4 
     5 
     6 import numpy as np 
     7 import matplotlib.pyplot as plt
     8 
     9 def estimate_coefficients(x, y): 
    10     # size of the dataset OR number of observations/points 
    11     n = np.size(x) 
    12   
    13     # mean of x and y
    14     # Since we are using numpy just calling mean on numpy is sufficient 
    15     mean_x, mean_y = np.mean(x), np.mean(y) 
    16   
    17     # calculating cross-deviation and deviation about x 
    18     SS_xy = np.sum(y*x - n*mean_y*mean_x) 
    19     SS_xx = np.sum(x*x - n*mean_x*mean_x) 
    20   
    21     # calculating regression coefficients 
    22     b_1 = SS_xy / SS_xx 
    23     b_0 = mean_y - b_1*mean_x 
    24   
    25     return(b_0, b_1)
    26 
    27     # x,y are the location of points on graph
    28     # color of the points change it to red blue orange play around
    29 
    30 
    31 
    32 def plot_regression_line(x, y, b): 
    33     # plotting the points as per dataset on a graph
    34     plt.scatter(x, y, color = "m",marker = "o", s = 30) 
    35 
    36     # predicted response vector 
    37     y_pred = b[0] + b[1]*x 
    38   
    39     # plotting the regression line
    40     plt.plot(x, y_pred, color = "g")
    41   
    42     # putting labels for x and y axis
    43     plt.xlabel('Size') 
    44     plt.ylabel('Cost') 
    45   
    46     # function to show plotted graph
    47     plt.show()
    48     
    49 
    50     
    51 
    52 
    53 def main(): 
    54     # Datasets which we create 
    55     x = np.array([ 1,   2,   3,   4,   5,   6,   7,   8,    9,   10]) 
    56     y = np.array([300, 350, 500, 700, 800, 850, 900, 900, 1000, 1200]) 
    57   
    58     # estimating coefficients 
    59     b = estimate_coefficients(x, y) 
    60     print("Estimated coefficients:
    b_0 = {} 
    b_1 = {}".format(b[0], b[1])) 
    61   
    62     # plotting regression line 
    63     plot_regression_line(x, y, b)
    64 
    65     
    66 if __name__ == "__main__": 
    67     main()
  • 相关阅读:
    shell提交hive sql保存运行过程日志
    hive中 exists与left semi join
    hbase shell 导出数据转json
    ubuntu使用
    fast json
    elasticsearch 用户密码配置
    linux 自带php切换xampp
    Ubuntu查看crontab运行日志
    Linux服务器 XAMPP后添加PHP和MYSQL环境变量
    HBuilder 模拟器
  • 原文地址:https://www.cnblogs.com/liuwenhan/p/11790071.html
Copyright © 2011-2022 走看看