#!/usr/bin/python # -*- coding:utf-8 -*- import csv import numpy as np import matplotlib.pyplot as plt import pandas as pd from sklearn.model_selection import train_test_split from sklearn.linear_model import LinearRegression if __name__ == "__main__": path = r"C:8.Advertising.csv" # # 手写读取数据 - 请自行分析,在8.2.Iris代码中给出类似的例子 # f = file(path) # x = [] # y = [] # for i, d in enumerate(f): # if i == 0: # continue # d = d.strip() # if not d: # continue # d = map(float, d.split(',')) # x.append(d[1:-1]) # y.append(d[-1]) # print x # print y # x = np.array(x) # y = np.array(y) # # Python自带库 # f = file(path, 'rb') # print f # d = csv.reader(f) # for line in d: # print line # f.close() # # numpy读入 # p = np.loadtxt(path, delimiter=',', skiprows=1) # print p # pandas读入 data = pd.read_csv(path) x = data[['TV', 'Radio', 'Newspaper']] x = data[['TV', 'Radio']] y = data['Sales'] print(x) print(y) # # # 绘制1 # plt.plot(data['TV'], y, 'ro', label='TV') # plt.plot(data['Radio'], y, 'g^', label='Radio') # plt.plot(data['Newspaper'], y, 'mv', label='Newspaer') # plt.legend(loc='lower right') # plt.grid() # plt.show() # # 绘制2 plt.figure(figsize=(9,12)) plt.subplot(311) plt.plot(data['TV'], y, 'ro') plt.title('TV') plt.grid() plt.subplot(312) plt.plot(data['Radio'], y, 'g^') plt.title('Radio') plt.grid() plt.subplot(313) plt.plot(data['Newspaper'], y, 'b*') plt.title('Newspaper') plt.grid() plt.tight_layout() plt.show() #一部分用于训练数据,一部分用于测试数据 x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1) # print x_train, y_train #线性回归 linreg = LinearRegression() #进行拟合 model = linreg.fit(x_train, y_train) print(model) #系数 print(linreg.coef_) #截距 print(linreg.intercept_) #进行验证 #预测值 y_hat = linreg.predict(np.array(x_test)) #误差 mse = np.average((y_hat - np.array(y_test)) ** 2) # Mean Squared Error rmse = np.sqrt(mse) # Root Mean Squared Error print(mse, rmse) # t = np.arange(len(x_test)) plt.plot(t, y_test, 'r-', linewidth=2, label='Test') plt.plot(t, y_hat, 'g-', linewidth=2, label='Predict') plt.legend(loc='upper right') plt.grid() plt.show()