zoukankan      html  css  js  c++  java
  • GBDT算法用于分类问题

    转自:https://zhuanlan.zhihu.com/p/46445201

    GBDT算法概述

    GBDT是boosting算法的一种,按照boosting的思想,在GBDT算法的每一步,用一棵决策树去拟合当前学习器的残差,获得一个新的弱学习器。将这每一步的决策树组合起来,就得到了一个强学习器。

    具体来说,假设有训练样本 [公式] ,第m-1步获得的集成学习器为 [公式] ,那么GBDT通过下面的递推式,获得一个新的弱学习器 [公式]

    [公式]

    其中 [公式] 是在函数空间 [公式] 上最小化损失函数,一般来说这是比较难以做到的。但是,如果我们只考虑精确地拟合训练数据的话,可以将损失函数 [公式] 看作向量 [公式] 上的函数。这样在第m-1轮迭代之后,向量位于 [公式] ,如果我们想进一步减小损失函数,则根据梯度下降法,向量移动的方向应为损失函数的负梯度方向,即:

    [公式]

    这样如果使用训练集: [公式] 去训练一棵树的话,就相当于朝着损失函数减小的方向又走了一步(当然在实际应用中需要shrinkage,也就是考虑学习率)。由此可见,GBDT在本质上还是梯度下降法,每一步通过学习一棵拟合负梯度(也就是所谓“残差”)的树,来使损失函数逐渐减小。

    GBDT用于分类问题

    将GBDT应用于回归问题,相对来说比较容易理解。因为回归问题的损失函数一般为平方差损失函数,这时的残差,恰好等于预测值与实际值之间的差值。每次拿一棵决策树去拟合这个差值,使得残差越来越小,这个过程还是比较intuitive的。而将GBDT用于分类问题,则显得不那么显而易见。下面我们就通过一个简单的二分类问题,去看看GBDT究竟是如何学习到一棵树的。

    类似于逻辑回归、FM模型用于分类问题,其实是在用一个线性模型或者包含交叉项的非线性模型,去拟合所谓的对数几率 [公式] 。而GBDT也是一样,只是用一系列的梯度提升树去拟合这个对数几率,实际上最终得到的是一系列CART回归树。其分类模型可以表达为:

    [公式]

    其中[公式] 就是学习到的决策树。

    清楚了这一点之后,我们便可以参考逻辑回归,单样本 [公式] 的损失函数可以表达为交叉熵:

    [公式]

    假设第k步迭代之后当前学习器为 [公式] ,将 [公式] 的表达式带入之后, 可将损失函数写为:

    [公式]

    可以求得损失函数相对于当前学习器的负梯度为:

    [公式]

    可以看到,同回归问题很类似,下一棵决策树的训练样本为: [公式] ,其所需要拟合的残差为真实标签与预测概率之差。于是便有下面GBDT应用于二分类的算法:

    • [公式] ,其中 [公式] 是训练样本中y=1的比例,利用先验信息来初始化学习器
    • For [公式]
      • 计算 [公式] ,并使用训练集 [公式] 训练一棵回归树 [公式] ,其中 [公式]
      • 通过一维最小化损失函数找到树的最优权重: [公式]
      • 考虑shrinkage,可得这一轮迭代之后的学习器 [公式][公式] 为学习率
    • 得到最终学习器为: [公式]

    以上就是将GBDT应用于二分类问题的算法流程。类似地,对于多分类问题,则需要考虑以下softmax模型:

    [公式]

    [公式]

    [公式]

    [公式]

    其中 [公式] [公式][公式] 个不同的tree ensemble。每一轮的训练实际上是训练了 [公式] 棵树去拟合softmax的每一个分支模型的负梯度。softmax模型的单样本损失函数为:

    [公式]

    这里的 [公式] 是样本label在k个类别上作one-hot编码之后的取值,只有一维为1,其余都是0。由以上表达式不难推导:

    [公式]

    可见,这k棵树同样是拟合了样本的真实标签与预测概率之差,与二分类的过程非常类似。

  • 相关阅读:
    Linux修改主机名称方法
    高精度模板(含加减乘除四则运算)
    背包问题(0-1背包,完全背包,多重背包知识概念详解)
    [Swust OJ 385]--自动写诗
    [Swust OJ 403]--集合删数
    [Swust OJ 409]--小鼠迷宫问题(BFS+记忆化搜索)
    [Swust OJ 360]--加分二叉树(区间dp)
    [Swust OJ 402]--皇宫看守(树形dp)
    [Swust OJ 581]--彩色的石子(状压dp)
    [Swust OJ 589]--吃西瓜(三维矩阵压缩)
  • 原文地址:https://www.cnblogs.com/leebxo/p/12923933.html
Copyright © 2011-2022 走看看