zoukankan      html  css  js  c++  java
  • 【机器学习】:Xgboost/LightGBM使用与调参技巧

    机器学习模型当中,目前最为先进的也就是xgboost和lightgbm这两个树模型了。那么我们该如何进行调试参数呢?哪些参数是最重要的,需要调整的,哪些参数比较一般,这两个模型又该如何通过代码进行调用呢?下面是一张总结了xgboost,lightbgm,catboost这三个模型调试参数的一些经验,以及每个参数需要的具体数值以及含义,供大家参考:

    一.Xgboost配合grid search进行网格搜索参数 

    实现代码如下:

    mport xgboost as xgb
    from sklearn import metrics
    from sklearn.model_selection import GridSearchCV
    
    def auc(m, train, test): 
        return (metrics.roc_auc_score(y_train, m.predict_proba(train)[:,1]),
                                metrics.roc_auc_score(y_test, m.predict_proba(test)[:,1]))
    
    # Parameter Tuning
    model = xgb.XGBClassifier()
    param_dist = {"max_depth": [10,30,50],
                  "min_child_weight" : [1,3,6],
                  "n_estimators": [200],
                  "learning_rate": [0.05, 0.1,0.16],}
    grid_search = GridSearchCV(model, param_grid=param_dist, cv = 3, 
                                       verbose=10, n_jobs=-1)
    grid_search.fit(train, y_train)
    
    grid_search.best_estimator_
    
    model = xgb.XGBClassifier(max_depth=3, min_child_weight=1,  n_estimators=20,
                              n_jobs=-1 , verbose=1,learning_rate=0.16)
    model.fit(train,y_train)
    
    print(auc(model, train, test))

    这里使用了自定义的auc作为模型的评价指标,输出如下:

    Fitting 3 folds for each of 27 candidates, totalling 81 fits
    (0.7479275227922775, 0.7430946047035487)

    二.LightGBM配合grid search进行网格搜索参数

    代码如下:

    import lightgbm as lgb
    from sklearn import metrics
    
    def auc2(m, train, test): 
        return (metrics.roc_auc_score(y_train,m.predict(train)),
                                metrics.roc_auc_score(y_test,m.predict(test)))
    
    lg = lgb.LGBMClassifier(silent=False)
    param_dist = {"max_depth": [25,50, 75],
                  "learning_rate" : [0.01,0.05,0.1],
                  "num_leaves": [300,900,1200],
                  "n_estimators": [200]
                 }
    grid_search = GridSearchCV(lg, n_jobs=-1, param_grid=param_dist, cv = 3, 
                               scoring="roc_auc", verbose=5)
    grid_search.fit(train,y_train)
    grid_search.best_estimator_
    
    #使用lgbm原生态的方式进行训练 d_train
    = lgb.Dataset(train, label=y_train, free_raw_data=False) params = {"max_depth": 3, "learning_rate" : 0.1, "num_leaves": 900, "n_estimators": 20} # Without Categorical Features model2 = lgb.train(params, d_train) print(auc2(model2, train, test)) #With Catgeorical Features cate_features_name = ["MONTH","DAY","DAY_OF_WEEK","AIRLINE","DESTINATION_AIRPORT", "ORIGIN_AIRPORT"] model2 = lgb.train(params, d_train, categorical_feature = cate_features_name) print(auc2(model2, train, test))

    第三种引用方式lbgm的方式是,sklearn和lgbm相结合,这样就可以使用sklearn对lgbm的运行结果快速进行评估。

    # coding: utf-8
    import lightgbm as lgb
    import pandas as pd
    from sklearn.metrics import mean_squared_error
    from sklearn.model_selection import GridSearchCV
    
    # 加载数据
    print('加载数据...')
    df_train = pd.read_csv('../data/regression.train.txt', header=None, sep='	')
    df_test = pd.read_csv('../data/regression.test.txt', header=None, sep='	')
    
    # 取出特征和标签
    y_train = df_train[0].values
    y_test = df_test[0].values
    X_train = df_train.drop(0, axis=1).values
    X_test = df_test.drop(0, axis=1).values
    
    print('开始训练...')
    # 直接初始化LGBMRegressor
    # 这个LightGBM的Regressor和sklearn中其他Regressor基本是一致的
    gbm = lgb.LGBMRegressor(objective='regression',
                            num_leaves=31,
                            learning_rate=0.05,
                            n_estimators=20)
    
    # 使用fit函数拟合
    gbm.fit(X_train, y_train,
            eval_set=[(X_test, y_test)],
            eval_metric='l1',
            early_stopping_rounds=5)
    
    # 预测
    print('开始预测...')
    y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)
    # 评估预测结果
    print('预测结果的rmse是:')
    print(mean_squared_error(y_test, y_pred) ** 0.5)

    这就是Xgboost/LightGBM的基本代码使用啦!

  • 相关阅读:
    React 16 源码瞎几把解读 【前戏】 为啥组件外面非得包个标签?
    nodejs 使用redis 管理session
    nodejs 优雅的连接 mysql
    mongodb 学习笔记 3 --- 查询
    mongodb 学习笔记 2 --- 修改器
    mongodb 学习笔记--- 基础知识
    看jquery3.3.1学js类型判断的技巧
    FIS3 大白话【一】
    Flutter 插件开发:以微信SDK为例
    最新Android面试题整理,收藏下吧值得拥有!
  • 原文地址:https://www.cnblogs.com/geeksongs/p/15418610.html
Copyright © 2011-2022 走看看