zoukankan      html  css  js  c++  java
  • sklearn的GridSearchCV——网格搜索超参数调优

    基本使用

    参数不冲突

    参数不冲突时,直接用一个字典传递参数和要对应的候选值给GridSearchCV即可

    我这里的参数冲突指的是类似下面这种情况:
    ① 参数取值受限:参数a='a'时,参数b只能取'b',参数a='A'时,参数b能取'b'或'B'
    ② 参数互斥:参数 a 或 b 二者只能选一个

    from sklearn import datasets
    from sklearn.svm import SVC
    from sklearn.model_selection import GridSearchCV
    iris = datasets.load_iris()
    model = SVC(random_state=seed)
    
    # 需调参数及候选值
    parameters = {
        'C': [0.1, 1, 10], 
        'kernel': ['rbf', 'linear']
    }
    
    # 评价依据
    # https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
    scores = {
        'acc': 'accuracy',         # 准确率
        'f1_mi': 'f1_micro',       # 一种多分类f1值
    }
    
    # 网格搜索实例
    gs = GridSearchCV(
        model,
        parameters,
        cv=5,                      # 交叉验证数
        scoring=scores,            # 评价指标
        refit='f1_mi',             # 在此指标下,用最优表现的参数重新训练模型
    #     return_train_score=True,   # gs.cv_results_额外保存训练集的评价结果
        verbose=1,                 # 日志信息,默认0不输出
        n_jobs=2                   # 并行加速
    )
    
    # 一共要跑的任务数=参数1候选值*...*参数i候选值*交叉验证数
    # 这里就是3*2*5=30
    gs.fit(iris.data, iris.target)

    借助 make_scorer 可以自定义评价指标,如果指标越小越好,那么需要设置greater_is_better=False,sklearn会将这样的指标取负,越小越好取负之后就等同于越大越好。

    from sklearn.metrics import make_scorer
    def custom_loss_func(y_true, y_pred):
        return len(y_true[y_true!=y_pred])/len(y_true)
    # greater_is_better=False,指标越小越好
    # needs_proba=False,指标通过标签计算,不是通过概率
    loss_socre = make_scorer(custom_loss_func, greater_is_better=False, needs_proba=False)
    scores = {
        'acc': 'accuracy',         # 准确率
        'f1_mi': 'f1_micro',       # 一种多分类f1值
        'loss': loss_socre         # 自定义评价指标
    }

    再通过 gs.best_params_ 获取最优模型的参数,gs.best_estimator_取得最优模型(想这样操作的话GridSearchCV的refit参数不能为False),

    print("最优参数")
    print(gs.best_params_)
    print("最佳模型的评分")
    print(gs.best_score_)
    print("最优模型")
    best_model = gs.best_estimator_  # GridSearchCV的refit参数不能为False

    gs.cv_results_ 存放了网格搜索的结果,如果想查看可以借助pandas,我们这里只列出了和评价指标有关的结果

    """
    用表格查看训练信息
    """
    cv_results = pd.DataFrame(gs.cv_results_)
    # 查看其他指标的结果和参数,比如这里按平均准确率排序
    cv_results = cv_results.sort_values(by="mean_test_acc", ascending=False)
    shown_columns = ["mean_test_"+col for col in scores.keys()] + ["params"]
    cv_results[shown_columns].head(3)

    参数冲突

    参数冲突时,互斥参数搜索空间用不同字典来描述,然后将这些字典放到列表中,再传递给GridSearchCV

    parameters = [
        {
            'C': [0.1, 1, 10], 
            'kernel': ['rbf', 'linear']
        },
        {
            'C': [0.1, 1, 10],
            'kernel': ['poly'],
            'degree': [1, 3, 5]
        }
    ]

    复合调参

    管道可以用来连接多个操作,比如特征选择+模型训练,数据处理+模型训练等等。如果这些操作也有参数可调,可以用 GridSearchCV 对它们一起调参

    from sklearn import datasets
    from sklearn.pipeline import Pipeline
    from sklearn.feature_selection import SelectKBest, chi2, f_classif
    from sklearn.svm import SVC
    from sklearn.model_selection import GridSearchCV
    iris = datasets.load_iris()
    
    pipe = Pipeline([
        ('selector', SelectKBest()),       # 特征选择
        ('model', SVC(random_state=seed))  # 模型
    ])
    
    # “双下划线”指定要调整的部件及其参数
    parameters = [
        {
            'selector__score_func': [chi2, f_classif],
            'selector__k': [2, 3, 4],
            'model__C': [0.1, 1, 10], 
            'model__kernel': ['rbf', 'linear']
        },
        {
            'selector__score_func': [chi2, f_classif],
            'selector__k': [2, 3, 4],
            'model__C': [0.1, 1, 10],
            'model__kernel': ['poly'],
            'model__degree': [1, 3, 5]
        }
    ]
    
    
    gs = GridSearchCV(
        pipe,
        parameters,
        cv=5,
        scoring='accuracy',
        verbose=1,
        n_jobs=2,
    )
    
    gs.fit(iris.data, iris.target)

    这时候获得的 best_estimator_ 是管道,我们可以用索引获取需要的组件(特征选择器或模型)

    print("最优组合")
    # best_pipe = gs.best_estimator_
    best_selector = gs.best_estimator_[0]
    best_model = gs.best_estimator_[1]
  • 相关阅读:
    log4Net使用
    VS Code入门
    用VS Code写Python
    C#(99):LINQ查询操作符实例
    C#(99):LINQ to Objects(2)
    spring mvc 配置对静态资源的访问
    EntLib 自动数据库连接字符串加密
    块级格式化上下文( Block formatting contexts)
    Entlib DAAB映射枚举类型
    js 继承
  • 原文地址:https://www.cnblogs.com/dogecheng/p/12791132.html
Copyright © 2011-2022 走看看