zoukankan      html  css  js  c++  java
  • 多项式回归

    多项式回归

    若希望回归模型更好的拟合训练样本数据,可以使用多项式回归器。

    一元多项式回归

    y=w0 + w1 x + w2 x2 + w3 x3 + ... + wd xd

    将高次项看做对一次项特征的扩展得到:

    y=w0 + w1 x1 + w2 x2 + w3 x3 + ... + wd xd

    那么一元多项式回归即可以看做为多元线性回归,可以使用LinearRegression模型对样本数据进行模型训练。

    所以一元多项式回归的实现需要两个步骤:

    1. 将一元多项式回归问题转换为多元线性回归问题(只需给出多项式最高次数即可)。

    2. 将1步骤得到多项式的结果中 w1 w2 .. 当做样本特征,交给线性回归器训练多元线性模型。

    使用sklearn提供的数据管线实现两个步骤的顺序执行:

    import sklearn.pipeline as pl
    import sklearn.preprocessing as sp
    import sklearn.linear_model as lm
    
    model = pl.make_pipeline(
        sp.PolynomialFeatures(10),  # 多项式特征扩展器
        lm.LinearRegression())      # 线性回归器

    案例:

    import numpy as np
    import sklearn.pipeline as pl
    import sklearn.preprocessing as sp
    import sklearn.linear_model as lm
    import sklearn.metrics as sm
    import matplotlib.pyplot as mp
    # 采集数据
    x, y = np.loadtxt('../data/single.txt', delimiter=',', usecols=(0,1), unpack=True)
    x = x.reshape(-1, 1)
    # 创建模型(管线)
    model = pl.make_pipeline(
        sp.PolynomialFeatures(10),  # 多项式特征扩展器
        lm.LinearRegression())      # 线性回归器
    # 训练模型
    model.fit(x, y)
    # 根据输入预测输出
    pred_y = model.predict(x)
    test_x = np.linspace(x.min(), x.max(), 1000).reshape(-1, 1)
    pred_test_y = model.predict(test_x)
    mp.figure('Polynomial Regression', facecolor='lightgray')
    mp.title('Polynomial Regression', fontsize=20)
    mp.xlabel('x', fontsize=14)
    mp.ylabel('y', fontsize=14)
    mp.tick_params(labelsize=10)
    mp.grid(linestyle=':')
    mp.scatter(x, y, c='dodgerblue', alpha=0.75, s=60, label='Sample')
    mp.plot(test_x, pred_test_y, c='orangered', label='Regression')
    mp.legend()
    mp.show()

    过于简单的模型,无论对于训练数据还是测试数据都无法给出足够高的预测精度,这种现象叫做欠拟合。

    过于复杂的模型,对于训练数据可以得到较高的预测精度,但对于测试数据通常精度较低,这种现象叫做过拟合。

    一个性能可以接受的学习模型应该对训练数据和测试数据都有接近的预测精度,而且精度不能太低。

    训练集R2   测试集R2
    0.3        0.4        欠拟合:过于简单,无法反映数据的规则
    0.9        0.2        过拟合:过于复杂,太特殊,缺乏一般性
    0.7        0.6        可接受:复杂度适中,既反映数据的规则,同时又不失一般性
  • 相关阅读:
    42 【docker】run命令
    41 【docker】初识
    37 【kubernetes】搭建dashboard
    36 【kubernetes】coredns
    34 【kubernetes】安装手册
    35 【kubernetes】configMap
    33 【kebernetes】一个错误的解决方案
    25 【python入门指南】如何编写测试代码
    26【python】sprintf风格的字符串
    24 【python入门指南】class
  • 原文地址:https://www.cnblogs.com/maplethefox/p/11526367.html
Copyright © 2011-2022 走看看