zoukankan      html  css  js  c++  java
  • 《ML模型超参数调节:网格搜索、随机搜索与贝叶斯优化》

    ML模型超参数调节:网格搜索、随机搜索与贝叶斯优化


    0.3192018.09.19 15:24:47字数 1,411阅读 8,013

    之前一直在阿里实习,最近终于闲了下来参加了一个Kaggle的比赛,记录一下比赛过程中对模型调参的一些经验。

    在进行机器学习的过程中,最为核心的一个概念就是参数,而参数又分为模型参数与超参数。模型参数,顾名思义就是我们使用的模型根据训练数据的分布学习到的参数,这一部分不需要我们人为的先验经验。超参数是在开始学习过程之前设置值的参数,而不是通过训练得到的参数数据。通常情况下,需要对超参数进行优化,给模型选择一组最优超参数,以提高学习的性能和效果。通常情况下,常用的超参数调参的方法有:网格搜索,随机搜索与贝叶斯优化。

    在下文我们以Kaggle中最常用的模型LightGBM与Google Analytics Customer Revenue Prediction比赛数据为例对这三种方法进行探索,最终我在比赛中采用的是贝叶斯优化。

    网格搜索:

    网格搜索是应用最广泛的超参数搜索算法,网格搜索通过查找搜索范围内的所有的点,来确定最优值。一般通过给出较大的搜索范围以及较小的步长,网格搜索是一定可以找到全局最大值或最小值的。但是,网格搜索一个比较大的问题是,它十分消耗计算资源,特别是需要调优的超参数比较多的时候。在比赛中,需要调参的模型数量与对应的超参数比较多,而涉及的数据量又比较大,因此相当的耗费时间。此外,由于给出的超参数组合比较多,因此一般都会固定多数参数,分步对1~2个超参数进行调解,这样能够减少时间但是缺难以自动化进行,而且由于目标参数一般是非凸的,因此容易陷入局部最小值。

    网格搜索的方法如下:

    import lightgbm as lgb
    from sklearn.model_selection import GridSearchCV
    
    
    def GridSearch(clf, params, X, y):
        cscv = GridSearchCV(clf, params, scoring='neg_mean_squared_error', n_jobs=1, cv=5)
        cscv.fit(X, y)
    
        print(cscv.cv_results_)
        print(cscv.best_params_)
    
    
    if __name__ == '__main__':
        train_X, train_y = get_data()
    
        param = {
            'objective': 'regression',
            'n_estimators': 275,
            'max_depth': 6,
            'min_child_samples': 20,
            'reg_lambd': 0.1,
            'reg_alpha': 0.1,
            'metric': 'rmse',
            'colsample_bytree': 1,
            'subsample': 0.8,
            'num_leaves' : 40,
            'random_state': 2018
            }
        regr = lgb.LGBMRegressor(**param)
    
        adj_params = {'n_estimators': range(100, 400, 10),
                     'min_child_weight': range(3, 20, 2),
                     'colsample_bytree': np.arange(0.4, 1.0),
                     'max_depth': range(5, 15, 2),
                     'subsample': np.arange(0.5, 1.0, 0.1),
                     'reg_lambda': np.arange(0.1, 1.0, 0.2),
                     'reg_alpha': np.arange(0.1, 1.0, 0.2),
                     'min_child_samples': range(10, 30)}
    
        GridSearch(regr , adj_params , train_X, train_y)
    

    根据我们设定的超参数分布范围来看,对所有的参数组合进行一一尝试是不现实的,这可能会消耗数天甚至数星期的时间,尤其是在大样本训练集上。

    随机搜索:

    与网格搜索相比,随机搜索并未尝试所有参数值,而是从指定的分布中采样固定数量的参数设置。它的理论依据是,如果随即样本点集足够大,那么也可以找到全局的最大或最小值,或它们的近似值。通过对搜索范围的随机取样,随机搜索一般会比网格搜索要快一些。但是和网格搜索的快速版(非自动版)相似,结果也是没法保证的。

    随机搜索的过程如下,使用方法与网格搜索完全一致:

    import lightgbm as lgb
    from sklearn.model_selection import  RandomizedSearchCV
    
    
    def RandomSearch(clf, params, X, y):
        rscv = RandomizedSearchCV(clf, params, scoring='neg_mean_squared_error', n_jobs=1, cv=5)
        rscv.fit(X, y)
    
        print(rscv.cv_results_)
        print(rscv.best_params_)
    
    
    if __name__ == '__main__':
        train_X, train_y = get_data()
    
        param = {
            'objective': 'regression',
            'n_estimators': 275,
            'max_depth': 6,
            'min_child_samples': 20,
            'reg_lambd': 0.1,
            'reg_alpha': 0.1,
            'metric': 'rmse',
            'colsample_bytree': 1,
            'subsample': 0.8,
            'num_leaves' : 40,
            'random_state': 2018
            }
        regr = lgb.LGBMRegressor(**param)
    
        adj_params = {'n_estimators': range(100, 400, 10),
                     'min_child_weight': range(3, 20, 2),
                     'colsample_bytree': np.arange(0.4, 1.0),
                     'max_depth': range(5, 15, 2),
                     'subsample': np.arange(0.5, 1.0, 0.1),
                     'reg_lambda': np.arange(0.1, 1.0, 0.2),
                     'reg_alpha': np.arange(0.1, 1.0, 0.2),
                     'min_child_samples': range(10, 30)}
    
        RandomSearch(regr , adj_params , train_X, train_y)
    

    贝叶斯优化:

    贝叶斯优化用于机器学习调参由J. Snoek(2012)提出,主要思想是,给定优化的目标函数(广义的函数,只需指定输入和输出即可,无需知道内部结构以及数学性质),通过不断地添加样本点来更新目标函数的后验分布(高斯过程,直到后验分布基本贴合于真实分布。简单的说,就是考虑了上一次参数的信息,从而更好的调整当前的参数。

    贝叶斯优化与常规的网格搜索或者随机搜索的区别是:

    1.贝叶斯调参采用高斯过程,考虑之前的参数信息,不断地更新先验;网格搜索未考虑之前的参数信息。
    2.贝叶斯调参迭代次数少,速度快;网格搜索速度慢,参数多时易导致维度爆炸。
    3.贝叶斯调参针对非凸问题依然稳健;网格搜索针对非凸问题易得到局部优最。

    贝叶斯优化调参的具体原理可以参考:拟合目标函数后验分布的调参利器:贝叶斯优化

    我们使用BayesOpt包来进行贝叶斯优化调参,安装命令如下所示:

    pip install bayesian-optimization
    

    BayesOpt包主要使用BayesianOptimization函数来创建一个优化对象,该函数接受一个模型评估函数function,这个function的输入应该是xgboost(或者其他ML模型)的超参数,输出是模型在测试集上的效果(可以是Accuracy,也可以是RMSE,取决于具体的任务,一般返回K-Fold的均值)。

    基于5-Fold的LightGBM贝叶斯优化的过程如下所示:

    import lightgbm as lgb
    from bayes_opt import BayesianOptimization
    
    train_X, train_y = None, None
    
    
    def BayesianSearch(clf, params):
        """贝叶斯优化器"""
        # 迭代次数
        num_iter = 25
        init_points = 5
        # 创建一个贝叶斯优化对象,输入为自定义的模型评估函数与超参数的范围
        bayes = BayesianOptimization(clf, params)
        # 开始优化
        bayes.maximize(init_points=init_points, n_iter=num_iter)
        # 输出结果
        params = bayes.res['max']
        print(params['max_params'])
        
        return params
    
    
    def GBM_evaluate(min_child_samples, min_child_weight, colsample_bytree, max_depth, subsample, reg_alpha, reg_lambda):
        """自定义的模型评估函数"""
    
        # 模型固定的超参数
        param = {
            'objective': 'regression',
            'n_estimators': 275,
            'metric': 'rmse',
            'random_state': 2018}
    
        # 贝叶斯优化器生成的超参数
        param['min_child_weight'] = int(min_child_weight)
        param['colsample_bytree'] = float(colsample_bytree),
        param['max_depth'] = int(max_depth),
        param['subsample'] = float(subsample),
        param['reg_lambda'] = float(reg_lambda),
        param['reg_alpha'] = float(reg_alpha),
        param['min_child_samples'] = int(min_child_samples)
    
        # 5-flod 交叉检验,注意BayesianOptimization会向最大评估值的方向优化,因此对于回归任务需要取负数。
        # 这里的评估函数为neg_mean_squared_error,即负的MSE。
        val = cross_val_score(lgb.LGBMRegressor(**param),
            train_X, train_y ,scoring='neg_mean_squared_error', cv=5).mean()
    
        return val
    
    
    if __name__ == '__main__':
        # 获取数据,这里使用的是Kaggle比赛的数据
        train_X, train_y = get_data()
        # 调参范围
        adj_params = {'min_child_weight': (3, 20),
                     'colsample_bytree': (0.4, 1),
                     'max_depth': (5, 15),
                     'subsample': (0.5, 1),
                     'reg_lambda': (0.1, 1),
                     'reg_alpha': (0.1, 1),
                     'min_child_samples': (10, 30)}
        # 调用贝叶斯优化
        BayesianSearch(GBM_evaluate, adj_params)
    

    迭代25次的优化结果如下所示:

    首先BayesianOptimization进行多次随机采样进行初始化,得到一个超参数与误差的分布结果,然后在这个结果的基础上使用贝叶斯优化来逼近最优超参数的分布。可以看出在所有的迭代结果中,第25次的结果最好,5-fold的MSE为2.64087。

    Initialization
    -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
     Step |   Time |      Value |   colsample_bytree |   max_depth |   min_child_samples |   min_child_weight |   n_estimators |   reg_alpha |   reg_lambda |   subsample | 
        1 | 00m32s |   -2.65636 |             0.6084 |     12.3355 |             26.6139 |             6.9177 |       337.4966 |      0.7969 |       0.1272 |      0.5945 | 
        2 | 00m29s |   -2.66585 |             0.4792 |      9.6159 |             13.1645 |            11.0249 |       372.2184 |      0.4597 |       0.1045 |      0.9052 | 
        3 | 00m30s |   -2.66461 |             0.4438 |      6.9836 |             12.0662 |            10.1247 |       378.3518 |      0.4865 |       0.8916 |      0.5287 | 
        4 | 00m19s |   -2.64282 |             0.8409 |     12.0801 |             20.8223 |            19.0301 |       165.1360 |      0.5061 |       0.5769 |      0.6494 | 
        5 | 00m23s |   -2.65333 |             0.5053 |      9.6624 |             27.2682 |            14.3314 |       254.0202 |      0.9768 |       0.1583 |      0.9284 | 
    Bayesian Optimization
    -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
     Step |   Time |      Value |   colsample_bytree |   max_depth |   min_child_samples |   min_child_weight |   n_estimators |   reg_alpha |   reg_lambda |   subsample | 
        6 | 00m33s |   -2.66755 |             0.5022 |      8.5932 |             29.5752 |             3.1476 |       100.0287 |      0.8609 |       0.4528 |      0.6313 | 
        7 | 00m43s |   -2.65496 |             0.4501 |      5.7514 |             29.5304 |            18.6252 |       399.8419 |      0.7158 |       0.9874 |      0.8430 | 
        8 | 00m24s |   -2.67955 |             0.6281 |      5.2346 |             10.2421 |            19.5566 |       101.5726 |      0.3053 |       0.7435 |      0.9929 | 
        9 | 00m28s |   -2.65857 |             0.4000 |     15.0000 |             10.0000 |             3.0000 |       182.3556 |      1.0000 |       1.0000 |      0.5000 | 
       10 | 00m50s |   -2.66609 |             0.6112 |     13.3527 |             29.0815 |             3.6173 |       399.9207 |      0.2949 |       0.9365 |      0.9969 | 
       11 | 00m33s |   -2.66158 |             0.6491 |      5.0802 |             29.9952 |            10.8271 |       176.0736 |      0.3629 |       0.1940 |      0.6537 | 
       12 | 00m31s |   -2.66521 |             0.4597 |     14.7685 |             28.4416 |            19.9523 |       135.5844 |      0.6747 |       0.9869 |      0.9979 | 
       13 | 00m43s |   -2.65927 |             0.6628 |      5.3680 |             10.1050 |            19.5642 |       296.3798 |      0.2352 |       0.6112 |      0.6804 | 
       14 | 00m37s |   -2.64896 |             0.8293 |     14.2701 |             10.5531 |            19.6700 |       213.0500 |      0.5891 |       0.3018 |      0.6622 | 
       15 | 00m42s |   -2.65717 |             0.7135 |     14.9457 |             29.4848 |            18.9740 |       300.9754 |      0.2594 |       0.9236 |      0.7721 | 
       16 | 00m27s |   -2.65335 |             0.9507 |     13.2854 |             10.5314 |             3.3486 |       100.1317 |      0.9374 |       0.1866 |      0.7813 | 
       17 | 00m46s |   -2.65510 |             0.5692 |      5.1025 |             29.8449 |            19.7204 |       339.3589 |      0.4485 |       0.8780 |      0.8208 | 
       18 | 00m35s |   -2.65088 |             0.7568 |      5.3993 |             10.0130 |            19.3298 |       163.8618 |      0.4143 |       0.5322 |      0.8676 | 
       19 | 00m36s |   -2.64531 |             0.7213 |     14.7743 |             28.2536 |            19.9979 |       193.2844 |      0.8638 |       0.2978 |      0.6375 | 
       20 | 00m52s |   -2.64566 |             0.8927 |      5.0751 |             29.3274 |             3.1158 |       303.7612 |      0.3423 |       0.2623 |      0.9909 | 
       21 | 00m40s |   -2.65703 |             0.5948 |     14.2128 |             10.4053 |            18.8860 |       169.5886 |      0.8990 |       0.1340 |      0.5641 | 
       22 | 00m36s |   -2.65216 |             0.5323 |     14.5715 |             10.3926 |             3.0552 |       256.2554 |      0.3696 |       0.9471 |      0.9737 | 
       23 | 00m56s |   -2.67228 |             0.8905 |     13.4996 |             10.0089 |            19.9199 |       399.1767 |      0.3840 |       0.6482 |      0.6469 | 
       24 | 00m33s |   -2.65297 |             0.8447 |      5.3688 |             10.1152 |             3.0179 |       122.9200 |      0.3059 |       0.1450 |      0.7361 | 
       25 | 00m43s |   -2.64087 |             0.9082 |      5.5366 |             24.8313 |            19.7798 |       233.6680 |      0.1137 |       0.8898 |      0.5926 | 
       26 | 00m53s |   -2.65588 |             0.8573 |      6.0967 |             11.2545 |             3.2333 |       323.8915 |      0.2285 |       0.9495 |      0.6646 | 
       27 | 00m46s |   -2.65086 |             0.8884 |     14.7948 |             29.1762 |            18.3960 |       225.9757 |      0.1427 |       0.7460 |      0.8326 | 
       28 | 00m45s |   -2.65912 |             0.5100 |      5.1808 |             10.9143 |            19.8540 |       252.4993 |      0.9033 |       0.9365 |      0.9709 | 
       29 | 00m50s |   -2.64676 |             0.8793 |      6.5897 |             13.6991 |             3.0135 |       231.6314 |      0.3555 |       0.7668 |      0.5065 | 
       30 | 00m53s |   -2.64426 |             0.9111 |      5.1766 |             24.9624 |             3.0822 |       270.3082 |      0.2562 |       0.8488 |      0.5413 | 
    
    {'n_estimators': 233.66804031835815, 'min_child_weight': 19.779801944146204, 'colsample_bytree': 0.9081747519556235, 'max_depth': 5.5366426714428965, 'subsample': 0.5925594065891966, 'reg_lambda': 0.8897581919934189, 'reg_alpha': 0.11372185364899876, 'min_child_samples': 24.8313091372136}
    
     
     
    9人点赞
     
     
     
  • 相关阅读:
    SVN库迁移整理方法----官方推荐方式
    SVN跨版本库迁移目录并保留提交日志
    微信公众号 发送图文消息
    Egret白鹭开发微信小游戏排行榜功能
    双滑动列表实现
    unity之资深工程师
    unity之高级工程师
    lua踩坑系列之浅拷贝与深拷贝
    lua之table.remove你不知道的坑
    unity之Layout Group居中显示
  • 原文地址:https://www.cnblogs.com/cx2016/p/12899522.html
Copyright © 2011-2022 走看看