zoukankan      html  css  js  c++  java
  • 模型调参---GridSearchCV

    一. GridSearchCV参数介绍

    导入模块:

    from sklearn.model_selection import GridSearchCV
    GridSearchCV 称为网格搜索交叉验证调参,它通过遍历传入的参数的所有排列组合,通过交叉验证的方式,返回所有参数组合下的评价指标得分,GridSearchCV 函数的参数详细解释如下:
    class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True)

    GridSearchCV官方说明

    参数:

    • estimator:scikit-learn 库里的算法模型;
    • param_grid:需要搜索调参的参数字典;
    • scoring:评价指标,可以是 auc, rmse,logloss等;
    • n_jobs:并行计算线程个数,可以设置为 -1,这样可以充分使用机器的所有处理器,并行数量越多,有利于缩短调参时间;
    • iid:如果设置为True,则默认假设数据在每折中具有相同地分布,并且最小化的损失是每个样本的总损失,而不是每折的平均损失。简单点说,就是如果你可以确定 cv 中每折数据分布一致就设置为 True,否则设置为 False;
    • cv:交叉验证的折数,默认为3折;

    常用属性:

    • cv_results_:用来输出cv结果的,可以是字典形式也可以是numpy形式,还可以转换成DataFrame格式
    • best_estimator_:通过搜索参数得到的最好的估计器,当参数refit=False时该对象不可用
    • best_score_:float类型,输出最好的成绩
    • best_params_:通过网格搜索得到的score最好对应的参数
    • best_index_:对应于最佳候选参数设置的索引(cv_results_数组)。cv_results _ [‘params’] [search.best_index_]中的dict给出了最佳模型的参数设置,给出了最高的平均分数(search.best_score_)。
    • scorer_:评分函数
    • n_splits_:交叉验证的数量
    • refit_time_:refit所用的时间,当参数refit=False时该对象不可用

    常用函数:

    • decision_function(X):返回决策函数值(比如svm中的决策距离)
    • fit(X,y=None,groups=None,fit_params):在数据集上运行所有的参数组合
    • get_params(deep=True):返回估计器的参数
    • inverse_transform(Xt):Call inverse_transform on the estimator with the best found params.
    • predict(X):返回预测结果值(0/1)
    • predict_log_proba(X): Call predict_log_proba on the estimator with the best found parameters.
    • predict_proba(X):返回每个类别的概率值(有几类就返回几列值)
    • score(X, y=None):返回函数
    • set_params(**params):Set the parameters of this estimator.
    • transform(X):在X上使用训练好的参数

    属性grid_scores_已经被删除,改用:

    means = grid_search.cv_results_['mean_test_score']
    params = grid_search.cv_results_['params']
    

      

     举例:

    使用多评价指标,必须设置refit参数,可以显示多指标的结果,但是最后显示最佳的参数时候必须指定一个指标,详解:解决方法

    param_test2 = { 'max_depth':[3,4,5,6], 'min_child_weight':[0.5,1,1.5]}
    scorers = {
        'precision_score': make_scorer(precision_score),
        'recall_score': make_scorer(recall_score),
        'accuracy_score': make_scorer(accuracy_score)
    }
    gsearch2 = GridSearchCV(estimator = XGBClassifier(         
        learning_rate =0.1, n_estimators=270, max_depth=4,min_child_weight=1, gamma=0, subsample=0.8,
        colsample_bytree=0.8, objective= 'binary:logistic', nthread=4,
        scale_pos_weight=1, seed=27),
        param_grid = param_test1,scoring=scorers,refit ='precision_score',n_jobs=4,iid=False, cv=5)
    gsearch2.fit(x_train_resampled,y_train_resampled)
    

    查看最佳结果:

    >>>gsearch2.best_params_,gsearch2.best_score_,gsearch2.cv_results_['mean_test_precision_score'],gsearch2.cv_results_['params']
    
    ({'max_depth': 9, 'min_child_weight': 1},
     0.8278796760710192,
     array([0.79985227, 0.80330522, 0.80645782, 0.8223829 , 0.81170396,
            0.80891565, 0.82691152, 0.82032078, 0.82220572, 0.82787968,
            0.82439509, 0.81863326]),
     [{'max_depth': 3, 'min_child_weight': 1},
      {'max_depth': 3, 'min_child_weight': 3},
      {'max_depth': 3, 'min_child_weight': 5},
      {'max_depth': 5, 'min_child_weight': 1},
      {'max_depth': 5, 'min_child_weight': 3},
      {'max_depth': 5, 'min_child_weight': 5},
      {'max_depth': 7, 'min_child_weight': 1},
      {'max_depth': 7, 'min_child_weight': 3},
      {'max_depth': 7, 'min_child_weight': 5},
      {'max_depth': 9, 'min_child_weight': 1},
      {'max_depth': 9, 'min_child_weight': 3},
      {'max_depth': 9, 'min_child_weight': 5}])

    查看交叉验证的中间结果:

    pd.DataFrame(gsearch2.cv_results_)

     画图显示最佳参数:

    grid_visualization = []
    for grid_pair in gsearch2.cv_results_['mean_test_precision_score']:
        grid_visualization.append(grid_pair)
    grid_visualization = np.array(grid_visualization)
    grid_visualization.shape = (4,3)
    sns.heatmap(grid_visualization,annot=True,cmap='Blues',fmt='.3f')
    plt.xticks(np.arange(3)+0.5,gsearch2.param_grid['min_child_weight'])
    plt.yticks(np.arange(4)+0.5,gsearch2.param_grid['max_depth'])
    plt.xlabel('min_child_weight')
    plt.ylabel('max_depth')
    

      

    参考文献:

    【1】集成树模型GridSearchCV,stacking

    【2】python机器学习库sklearn——参数优化(网格搜索GridSearchCV、随机搜索RandomizedSearchCV、hyperopt)

    【3】XGBoost参数调优完全指南

    【4】使用GridSearchCV进行网格搜索(比较全)

    【5】当GridSearch遇上XGBoost 一段代码解决调参问题

  • 相关阅读:
    二分法查找(C语言)
    冒泡排序法(C语言)
    Python 字符串操作方法大全
    guns搭建笔记
    mysql数据库下载及安装
    docker安装
    自动化学习路径及问题汇总目录
    UI自动化使用docker做并行执行
    allure趋势图无数据
    allure报告不显示@Attachment
  • 原文地址:https://www.cnblogs.com/nxf-rabbit75/p/9353061.html
Copyright © 2011-2022 走看看