zoukankan      html  css  js  c++  java
  • sklearn.model_selection.RandomizedSearchCV随机搜索超参数

    GridSearchCV可以保证在指定的参数范围内找到精度最高的参数,但是这也是网格搜索的缺陷所在,它要求遍历所有可能参数的组合,在面对大数据集和多参数的情况下,非常耗时。这也是我通常不会使用GridSearchCV的原因,一般会采用后一种RandomizedSearchCV随机参数搜索的方法

    RandomizedSearchCV的使用方法其实是和GridSearchCV一致的,但它以随机在参数空间中采样的方式代替了GridSearchCV对于参数的网格搜索,在对于有连续变量的参数时,RandomizedSearchCV会将其当作一个分布进行采样这是网格搜索做不到的,它的搜索能力取决于设定的n_iter参数

    函数用法:

    class sklearn.model_selection.RandomizedSearchCV(estimator, param_distributions, *, n_iter=10, 
    scoring=None, n_jobs=None, iid='deprecated', refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
    random_state=None, error_score=nan, return_train_score=False)

    参数详解:

    estimator:估计器

    param_distributions 字典或字典列表:参数字典,key是参数名,values是参数范围

    n_iter int,默认= 10:抽取样本是训练次数

    更多参数参考:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

    RandomSearchCV是如何"随机搜索"的

    考察其源代码,其搜索策略如下:
    (a)对于搜索范围是distribution的超参数,根据给定的distribution随机采样;
    (b)对于搜索范围是list的超参数,在给定的list中等概率采样;
    (c)对a、b两步中得到的n_iter组采样结果,进行遍历。
    (补充)如果给定的搜索范围均为list,则不放回抽样n_iter次。

    import numpy as np
    from scipy.stats import randint as sp_randint
    from sklearn.model_selection import RandomizedSearchCV
    from sklearn.datasets import load_digits
    from sklearn.ensemble import RandomForestClassifier
    
    # 载入数据
    digits = load_digits()
    X, y = digits.data, digits.target
    
    # 建立一个分类器或者回归器
    clf = RandomForestClassifier(n_estimators=20)
    
    # 给定参数搜索范围:list or distribution
    param_dist = {"max_depth": [3, None],                     #给定list
                  "max_features": sp_randint(1, 11),          #给定distribution
                  "min_samples_split": sp_randint(2, 11),     #给定distribution
                  "bootstrap": [True, False],                 #给定list
                  "criterion": ["gini", "entropy"]}           #给定list
    
    # 用RandomSearch+CV选取超参数
    n_iter_search = 20
    random_search = RandomizedSearchCV(clf, param_distributions=param_dist,
                                       n_iter=n_iter_search, cv=5, iid=False)
    clf=random_search.fit(X, y)
    clf.best_params_ 
    {'bootstrap': False,
     'criterion': 'entropy',
     'max_depth': None,
     'max_features': 9,
     'min_samples_split': 8}
  • 相关阅读:
    MySQL-索引
    MySQL-存储引擎
    MySQL-基本概念
    Elasticsearch-分片原理2
    Elasticsearch-分片原理1
    [NOIP模拟33]反思+题解
    [NOIP模拟测试32]反思+题解
    [NOIP模拟测试31]题解
    [jzoj5840]Miner 题解(欧拉路)
    [NOIP模拟测试30]题解
  • 原文地址:https://www.cnblogs.com/cgmcoding/p/13634531.html
Copyright © 2011-2022 走看看