zoukankan      html  css  js  c++  java
  • 机器学习十讲学习笔记第二讲

    矩阵的逆

    • 对于n*n方阵A,如果存在矩阵B使得AB = BA = I,则称BA的逆矩阵,记为A-1

    • 若A为可逆矩阵,则其逆矩阵是唯一的

    • 如何判断矩阵是否可逆?

      • 行列式不等于0
      • 满秩
      • 行(或列)向量组线性无关
      • ...

    回归:

    • 回归如今指的用一个或多个自变量来预测因变量的数学方法
    • 在机器学习中,回归指的是一类预测变量为连续值的有监督学习方法

    在回归模型中,需要预测的变量叫做因变量,用来解释因变量变化的变量叫做自变量

    一元线性回归:

    • 模型为y = w1x + w0,其中w0,w1为回归系数
    • 给定训练集D = {(x1,y1),...(xn,yn)},我们的目标是找到一条直线y = w1x + w0使得所有样本尽可能落在它的附近。

    多元线性回归:

    • y = w1x1 + w22+...+ wdxd +w0

    • 训练集D = {(x1,y1),...(xn,yn)}

    • ...

    • 假设训练集的特征部分记为n*(d+1)矩阵X,其中最后一列取值全为1

    • 标签部分记为...

    • 最小二乘的参数估计,如果变量之间存在较强的共线性,则XTX近似奇异,对参数的估计变得不准确,造成过度拟合现象。

    • 解决方法:正则化,主成分回归、偏最小二乘回归

    image-20210201235744504

    • 正则化可以减小线性回归的过度拟合和多重共线性等问题

    岭回归:

    • 岭回归:线性回归目标函数加上对W的惩罚函数λ||w||22λ||w||22
    • 线性回归目标函数:(Xwy)T(Xwy)(Xw−y)T(Xw−y)
    • 岭回归目标函数:(Xwy)T(Xwy)+λ||w||22(Xw−y)T(Xw−y)+λ||w||22
    • 对w求导并令导数等于零易得:xˆridge=(XTX+λI)1XTyx^ridge=(XTX+λI)−1XTy
    • 根据岭迹做超参数λλ的选择(岭迹分析)

    LASSO:

    • LASSO是一种系数压缩估计法,它的基本思想是通过追求稀疏性自动选择重要的变量
    • LASSO的目标函数:(Xwy)T(Xwy)+λ||w||1(Xw−y)T(Xw−y)+λ||w||1
    • LASSO的解xˆLASSOx^LASSO没有解析表达式,常用的求解算法包括坐标下降法、LARS算法和LSTA算法等

    回归模型的评价指标

    • 均方误差:MSE(y,yˆ)=1nni=1(yiyˆi)2MSE(y,y^)=1n∑i=1n(yi−y^i)2
    • 均方根误差:RMSE(y,yˆ)=1nni=1(yiyˆi)2−−−−−−−−−−−−−−√RMSE(y,y^)=1n∑i=1n(yi−y^i)2
    • 平均绝对误差:MAE(y,yˆ)=1nni=1|yiyˆi|MAE(y,y^)=1n∑i=1n|yi−y^i|
    • 决定系数:R2(y,yˆ)=1SSresSStot=1ni=1(yiyˆi)2ni=1(yiy¯i)2R2(y,y^)=1−SSresSStot=1−∑i=1n(yi−y^i)2∑i=1n(yi−y¯i)2

    案例:使用回归预测模型预测鲍鱼年龄

    1. 使用Python实现线性回归和岭回归算法,并与Sklearn中的实现进行对比
    2. 借助Sklearn工具,对线性回归、岭回归和LASSO三种模型的预测效果使用MAE和决定系数进行效果评估
    3. 残差图和正则化路径对模型表现进行分析
    • 读入数据:

      • 鲍鱼是一种贝类,在世界许多地方被认为是美味佳肴。是铁和泛酸的极好来源,也是澳大利亚、美洲和东亚的营养食品资源和农业。100克鲍鱼就能给人体提供超过 20% 上述每日摄入营养素。鲍鱼的经济价值与年龄正相关。因此,准确检测鲍鱼的年龄对养殖户和消费者确定鲍鱼的价格具有重要意义。

        然而,目前确定鲍鱼年龄的技术是相当昂贵和低效的。农场主通常把鲍鱼的壳割下来,用显微镜数鲍鱼环的数量,以估计鲍鱼的年龄。因此判断鲍鱼的年龄很困难,主要是因为鲍鱼的大小不仅取决于年龄,而且还取决于食物的供应情况。此外,鲍鱼有时会形成所谓的“发育不良”群体,这些群体的生长特性与其他鲍鱼种群有很大不同。这种复杂的方法增加了成本,限制了应用范围。本案例的目标是使用机器学习中的回归模型,找出最佳的指标来预测鲍鱼环数,进而预测鲍鱼的年龄。

        import pandas as pd
        import warnings
        warnings.filterwarnings('ignore')
        data = pd.read_csv("./input/abalone_dataset.csv")
        data.head()
        

        image-20210203104054282

      数据集一共有 4177 个样本,每个样本有 9 个特征。其中 rings 为鲍鱼环数,能够代表鲍鱼年龄,是预测变量。除了 sex 为离散特征,其余都为连续变量。

      首先借助 seaborn 中的 countplot 函数绘制条形图,观察 sex 列的取值分布情况。

    从以上连续特征之间的散点图我们可以看到一些基本的结果:

    • 例如从第一行可以看到鲍鱼的长度 length 和鲍鱼直径 diameter 、鲍鱼高度 height 存在明显的线性关系。鲍鱼长度与鲍鱼的四种重量之间存在明显的非线性关系。
    • 观察最后一行,鲍鱼环数 rings 与各个特征均存在正相关性,其中与 height 的线性关系最为直观。
    • 观察对角线上的直方图,可以看到幼鲍鱼(sex 取值为“I”)在各个特征上的取值明显小于其他成年鲍鱼。而雄性鲍鱼(sex 取值为“M”)和雌性鲍鱼(sex 取值为“F”)各个特征取值分布没有明显的差异。

    为了定量地分析特征之间的线性相关性,我们计算特征之间的相关系数矩阵,并借助热力图将相关性可视化。

    接下来实现线性回归和岭回归

    • 如果矩阵 XTXXTX 为满秩(行列式不为 0 ),则简单线性回归的解为 w^=(XTX)1XTyw^=(XTX)−1XTy 。实现一个函数 linear_regression,其输入为训练集特征部分和标签部分,返回回归系数向量。 我们借助 numpy 工具中的 np.linalg.det 函数和 np.linalg.inv 函数分别求矩阵的行列式和矩阵的逆。

    可见我们求得的模型为:

    y=-1.12×length+10×diameter+20.74×height+9.61×whole_weight20.05×shucked_weight12.07×viscera_weight+6.55×shell_weight0.88×sex_F+0.87×sex_M+4.32y=−1.12×length+10×diameter+20.74×height+9.61×whole_weight−20.05×shucked_weight−12.07×viscera_weight+6.55×shell_weight+0.88×sex_F+0.87×sex_M+4.32

    sklearn 中的 linear_model 模块实现了常见的线性模型,包括线性回归、岭回归和 LASSO 等。对应的算法和类名如下表所示。

    下面我们使用 LinearRegression 构建线性回归模型。注意,此时传给 fit 方法的训练集的特征部分不包括 ones 列。模型训练完成后,lr.coef_ 属性和 lr.intercept_ 属性分别保存了学习到的回归系数向量和截距项。

    岭回归的解为 w^Ridge=(XTX+λI)1XTyw^Ridge=(XTX+λI)−1XTy ,其中 λλ 为正则系数,II 为单位矩阵。我们实现 ridge_regression 函数来求解,它包括三个参数:训练集特征矩阵 X, 训练集标签向量 y,以及正则化次数 ridge_lambda

    单位矩阵可使用 np.eye 函数自动生成,其大小为 (d+1)(d+1),即与特征矩阵 X 的列数(X.shape[1])相同。

    在鲍鱼训练集上使用 ridge_regression 函数训练岭回归模型,正则化系数设置为 1 。

  • 相关阅读:
    关于 IIS 上运行 ASP.NET Core 站点的“HTTP 错误 500.19”错误
    下单快发货慢:一个 JOIN SQL 引起 SqlClient 读取数据慢的奇特问题
    ASP.NET Core 2.2 项目升级至 3.0 备忘录
    corefx 源码学习:SqlClient 是如何同步建立 Socket 连接的
    Chimee
    electron-vue:Vue.js 开发 Electron 桌面应用
    H5网页适配 iPhoneX,就是这么简单
    经典文摘:饿了么的 PWA 升级实践(结合Vue.js)
    Table Dragger
    分享8个网站开发中最好用的打印页面插件
  • 原文地址:https://www.cnblogs.com/52bb/p/14476844.html
Copyright © 2011-2022 走看看