zoukankan      html  css  js  c++  java
  • 使用catboost解决ML中高维度不平衡数据集挑战的解决方案

     

    python机器学习-乳腺癌细胞挖掘(博主亲自录制视频,包含catboost实战代码)

     

    https://study.163.com/course/introduction.htm?courseId=1005269003&utm_campaign=commission&utm_source=cp-400000000398149&utm_medium=share

     

     

    https://blog.csdn.net/myboyliu2007/article/details/80256681

    什么是不平衡的数据集
    本文中所涉及的数据分类之间不均匀,有的几千条,有的几十条,这种数据集被认为是不平衡的数据集。 这些类型的数据集通常存在于预测维护系统,销售倾向,欺诈识别,智能诊断等应用领域中。

    例如,在电子病历中的呼吸科疾病诊断中,以2200条测试数据来看具体分布,数据分布很不均匀,最多的是1214例,最少的37例。各个疾病的分布如下。

    J20 急性支气管炎 97
    J15 肺炎 1214
    J45 哮喘 149
    J44 闭塞性细支气管炎 88
    A16 原发性肺结核 379
    A17 结核性脑膜炎 78
    T17 支气管异物 158
    J06 上感 37

    我们的模型大概设置了164个特征的不平衡数据集,所有可以认定为是一个高维不平衡的机器学习问题。 在原来的使用sklearn的逻辑回归算法时候,模型在样本数比较少的的分类中召回率很低,并且会错分类为样本数最多的分类,也就是说模型发生了过拟合。

    逻辑回归算法的模型性能
    一般ML算法在不平衡数据集上的性能:
    为了对catboost模型进行比较,我们使用了2种模型,一个是sklearn的逻辑回归,还有就是xgboost,使用这2种算法分别对数据进行了训练。

    逻辑回归:

    0.843858969
    [[ 167 35 20 2 2 1 22 0]
    [ 47 1022 19 13 32 9 18 1]
    [ 13 7 185 11 4 0 5 1]
    [ 0 5 9 211 3 0 0 0]
    [ 5 46 1 3 345 17 4 0]
    [ 2 4 0 0 5 83 0 0]
    [ 8 14 6 4 2 0 136 0]
    [ 0 0 0 1 0 0 2 29]]
    precision recall f1-score support

    J20 急性支气管炎 0.69 0.67 0.68 249
    J15 肺炎 0.90 0.88 0.89 1161
    J45 哮喘 0.77 0.82 0.79 226
    J44 闭塞性细支气管炎 0.86 0.93 0.89 228
    A16 原发性肺结核 0.88 0.82 0.85 421
    A17 结核性脑膜炎 0.75 0.88 0.81 94
    T17 支气管异物 0.73 0.80 0.76 170
    J06 上感 0.94 0.91 0.92 32

    avg / total 0.85 0.84 0.84 2581

    正如我们可以清楚地看到,模型预测急性支气管炎的能力很差(这种疾病样本数比较少)(Precision = 0.67 Recall = 0.67),整体模型精确度大约为85%。

    xgboost:

    0.865909090909
    [[ 22 37 8 2 3 0 10 0]
    [ 12 960 8 2 12 0 4 0]
    [ 5 18 68 6 2 0 2 0]
    [ 0 7 9 39 3 0 2 0]
    [ 1 43 2 0 256 3 4 0]
    [ 0 2 0 0 7 54 0 0]
    [ 1 13 3 0 2 0 97 0]
    [ 1 1 0 0 1 0 0 28]]
    precision recall f1-score support

    J20 急性支气管炎 0.52 0.27 0.35 82
    J15 肺炎 0.89 0.96 0.92 998
    J45 哮喘 0.69 0.67 0.68 101
    J44 闭塞性细支气管炎 0.80 0.65 0.72 60
    A16 原发性肺结核 0.90 0.83 0.86 309
    A17 结核性脑膜炎 0.95 0.86 0.90 63
    T17 支气管异物 0.82 0.84 0.83 116
    J06 上感 1.00 0.90 0.95 31

    avg / total 0.86 0.87 0.86 1760

    xgboost模型没有经过调参,性能比逻辑回归有了一定提升,但是对于少样本数据的预测结果变得更差了。急性支气管炎的结果为:(Precision = 0.52 Recall = 0.27),整体模型精确度大约为86%。

    可以通过下面一些技巧来提升小样本的分类效果
    以下是我们处理这个问题的方式,并且训练了一个预测精度92%的小样本的模型:

    减少特征的数量:
    更多特征的变量导致过度拟合数据,通过删除一些没有相关性的变量(需要专业知识)以及缺失值超过平均水平(本例考虑为15%)的变量。我们还计算了变量的重要性得分并据此删掉一些低重要性得分的特征变量。

    对小样本分类的数据进行oversample
    对于一般的不平衡问题我们采用oversample/undersample 的方法可能就能取得比较好的结果。但缺点就是如果overdample的比例过大会导致过拟合,undersample的比例过小可能会导致欠拟合。这个得经过实验去比较分析了。

    提高小样本分类的召回率:
    正如我们最初所看到的,由于在小样本的分类上会错误分类到大样本的分类上,所以我们试图找到一种可以解决小样本过拟合问题的模型。我们选择了一个非常强大的机器学习算法,叫做“catboost”,它对过拟合是鲁棒的。

    使用catboost
    由于在Kaggle的很多比赛中都是用了增强型机器学习算法,因此我们选择了Catboost算法,而且由于它对于小样本的过拟合问题解决的比较好,所以我们尝试使用这种算法。

    Catboost可以通过过拟合检测器来防止过度拟合模型,从而使得模型的泛化能力更好。 它基于一种与标准梯度下降方式的优化方法不同的方法,叫做Gradient Boosting。

    我们在没有调参的情况下,使用catboost训练后结果如下:

    0.833636364
    [[ 21 58 8 0 4 0 6 0]
    [ 6 1159 14 2 24 2 7 0]
    [ 2 21 108 9 5 0 4 0]
    [ 0 10 1 58 14 0 5 0]
    [ 0 65 3 4 301 4 2 0]
    [ 0 7 0 0 16 55 0 0]
    [ 2 43 7 3 1 0 102 0]
    [ 0 2 0 0 4 0 1 30]]
    precision recall f1-score support

    J20 急性支气管炎 0.68 0.22 0.33 97
    J15 肺炎 0.85 0.95 0.90 1214
    J45 哮喘 0.77 0.72 0.74 149
    J44 闭塞性细支气管炎 0.76 0.66 0.71 88
    A16 原发性肺结核 0.82 0.79 0.80 379
    A17 结核性脑膜炎 0.90 0.71 0.79 78
    T17 支气管异物 0.80 0.65 0.72 158
    J06 上感 1.00 0.81 0.90 37

    avg / total 0.83 0.83 0.82 2200

    最终的结果
    经过对超参数的调参,主要是学习率和迭代次数的调整,还有就是对小样本的数据进行oversample,最终得到一个相对理想的结果。

    ('error:', 0.7361311844473497)
    0.890072639
    [[190 6 6 0 1 0 3 0]
    [ 9 867 7 1 27 2 5 0]
    [ 3 11 162 7 1 0 4 0]
    [ 0 0 3 175 0 0 0 0]
    [ 2 52 1 2 264 9 1 0]
    [ 0 8 0 0 7 59 0 0]
    [ 3 26 5 7 4 0 98 0]
    [ 0 0 1 1 2 0 0 23]]
    precision recall f1-score support

    J20 急性支气管炎 0.92 0.92 0.92 206
    J15 肺炎 0.89 0.94 0.92 918
    J45 哮喘 0.88 0.86 0.87 188
    J44 闭塞性细支气管炎 0.91 0.98 0.94 178
    A16 原发性肺结核 0.86 0.80 0.83 331
    A17 结核性脑膜炎 0.84 0.80 0.82 74
    T17 支气管异物 0.88 0.69 0.77 143
    J06 上感 1.00 0.85 0.92 27

    avg / total 0.89 0.89 0.89 2065

     https://study.163.com/provider/400000000398149/index.htm?share=2&shareId=400000000398149( 欢迎关注博主主页,学习python视频资源,还有大量免费python经典文章)


    QQ:231469242

  • 相关阅读:
    Redis_数据类型
    python 单独设置在plot每条线的label为中文
    制作9patch图片心得——Android开发使用类似QQ聊天的冒泡对话框
    Oracle数据库实验一建立数据库
    Postman使用总结
    jmeter使用小结
    python实现系统调用cmd命令的模块---subprocess模块
    程序进程线程之间的区别
    Fiddler抓包工具简介
    MySQL基础SQL命令---增删改查
  • 原文地址:https://www.cnblogs.com/webRobot/p/10359031.html
Copyright © 2011-2022 走看看