zoukankan      html  css  js  c++  java
  • 11.线性回归的调用

    #!/usr/bin/python
    # -*- coding:utf-8 -*-
    
    import csv
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LinearRegression
    
    
    if __name__ == "__main__":
        import sys

        # Fit an ordinary-least-squares model of Sales on advertising spend,
        # print its coefficients and test-set error, and plot both the raw
        # data and the predictions.
        #
        # The CSV path may be given as the first command-line argument. The
        # original hard-coded r"C:8.Advertising.csv" had lost its directory
        # separators: on Windows "C:8..." is resolved relative to drive C's
        # current directory, elsewhere it is a literal file name. Default to
        # the file in the working directory instead.
        path = sys.argv[1] if len(sys.argv) > 1 else "8.Advertising.csv"

        # pandas is used here; csv.reader(open(path, newline="")) or
        # np.loadtxt(path, delimiter=",", skiprows=1) would also work.
        data = pd.read_csv(path)

        # Predictors. 'Newspaper' is deliberately left out; append it to this
        # list to fit the three-feature model instead.
        feature_cols = ["TV", "Radio"]
        x = data[feature_cols]
        y = data["Sales"]
        print(x)
        print(y)

        # Scatter Sales against each advertising channel, one panel per channel.
        plt.figure(figsize=(9, 12))
        panels = (("TV", "ro"), ("Radio", "g^"), ("Newspaper", "b*"))
        for idx, (col, fmt) in enumerate(panels, start=1):
            plt.subplot(3, 1, idx)
            plt.plot(data[col], y, fmt)
            plt.title(col)
            plt.grid()
        plt.tight_layout()
        plt.show()

        # Hold out part of the data for testing (train_test_split's default
        # test fraction); random_state=1 makes the split reproducible.
        x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1)

        linreg = LinearRegression()
        model = linreg.fit(x_train, y_train)  # fit() returns the estimator itself
        print(model)
        print(linreg.coef_)       # one coefficient per column in feature_cols
        print(linreg.intercept_)

        # Predict on the DataFrame rather than np.array(x_test), so the input
        # keeps the feature names the model was fitted with. Passing a bare
        # ndarray makes recent scikit-learn warn about missing feature names.
        y_hat = linreg.predict(x_test)
        mse = np.mean((y_hat - np.asarray(y_test)) ** 2)  # Mean Squared Error
        rmse = np.sqrt(mse)                               # Root Mean Squared Error
        print(mse, rmse)

        # Overlay true and predicted Sales for the test samples, in split order.
        t = np.arange(len(x_test))
        plt.plot(t, y_test, "r-", linewidth=2, label="Test")
        plt.plot(t, y_hat, "g-", linewidth=2, label="Predict")
        plt.legend(loc="upper right")
        plt.grid()
        plt.show()
  • 相关阅读:
    Latin1的所有字符编码
    Qt自定义委托在QTableView中绘制控件、图片、文字(内容比较全)
    Delphi5 update1的序列号
    Access Violation at address 00000000.Read of address 00000000 解决办法
    RealThinClient学习(一)
    使用WebDriver遇到的那些坑
    谱聚类(Spectral Clustering)详解
    Asp.net Mvc4默认权限详细(上)
    ASP.NET Web API中的JSON和XML序列化
    [珠玑之椟]估算的应用与Little定律
  • 原文地址:https://www.cnblogs.com/xiaochi/p/11264755.html
Copyright © 2011-2022 走看看