zoukankan      html  css  js  c++  java
  • 线性回归 Python实现

     1 import numpy as np
     2 import pylab
     3 
     4 
     5 def plot_data(data, b, m):
     6     x = data[:, 0]
     7     y = data[:, 1]
     8     y_predict = m*x + b
     9     pylab.plot(x, y_predict, 'k-')
    10     pylab.plot(x, y, 'o')
    11     pylab.show()
    12 
    13 
    14 def gradient(data, initial_b, initial_m, learning_rate, num_iter):
    15     b = initial_b
    16     m = initial_m
    17     x = data[:, 0]
    18     y = data[:, 1]
    19     n = float(len(data))
    20     for i in range(num_iter):
    21         b_gradient = -(1/n)*(y - m*x-b)
    22         b_gradient = np.sum(b_gradient, axis=0)
    23         m_gradient = -(1/n)*x*(y - m*x - b)
    24         m_gradient = np.sum(m_gradient)
    25         theta0 = b - (learning_rate*b_gradient)
    26         theta1 = m - (learning_rate*m_gradient)
    27         b = theta0
    28         m = theta1
    29         if i % 100 == 0:
    30             j = (np.sum((y - m*x - b)**2))/n
    31             print("参数b:{},参数m:{},损失值:{}".format(b, m, j))
    32     return [b, m]
    33 
    34 
    35 def linear_regression():
    36     # 导入数据
    37     data = np.loadtxt('data.csv', delimiter=',')
    38 
    39     # 线性回归参数设定
    40     learning_rate = 0.001
    41     initial_b = 0.0
    42     initial_m = 0.0
    43     num_iter = 1000
    44 
    45     b, m = gradient(data, initial_b, initial_m, learning_rate, num_iter)
    46     plot_data(data, b, m)
    47 
    48 if __name__ == '__main__':
    49     linear_regression()

    测试结果:

  • 相关阅读:
    内存优化
    OpenThreads库学习
    WPS/office使用技巧系列
    NB-IOT学习
    JSON和XML
    物联网平台学习
    .net提供的5种request-response方法一
    HTML5之IndexedDB使用详解
    jQuery圆形统计图实战开发
    用javascript将数据导入Excel
  • 原文地址:https://www.cnblogs.com/reaptomorrow-flydream/p/9204870.html
Copyright © 2011-2022 走看看