zoukankan      html  css  js  c++  java
  • GridSearchCV网格搜索得到最佳超参数, 在K近邻算法中的应用

      最近在学习机器学习中的K近邻算法, KNeighborsClassifier 看似简单实则里面有很多的参数配置, 这些参数直接影响到预测的准确率. 很自然的问题就是如何找到最优参数配置? 这就需要用到GridSearchCV 网格搜索模型. 

      在没有学习到GridSearchCV 网格搜索模型之前, 寻找最优参数配置是通过人为改变参数, 来观察预测结果准确率的. 具体步骤如下:

    1. 修改参数配置
    2. fit 训练集
    3. 预测测试集
    4. 预测结果与真实结果对比
    5. 重复上述步骤

      GridSearchCV 网格搜索模型寻找最优参数的步骤如下:

    1. 将各种参数配置封装为列表
    2. 实例化分类器
    3. 使用GridSearchCV 为分类器和参数建模
    4. 实例化模型, 并用新的模型对象fit训练集
    5. 得到最好的参数配置
    6. 用最优参数去预测数据

      于是我的疑问就来了, GridSearchCV 并没有去预测测试集,进而得到预测结果,并在与真实结果的对比中找到最优的参数配置, 没有这个步骤,它是怎么得到最优参数的? 搜索了很多,终于在这个网页中得到了想要的信息: python – GridSearchCV是否执行交叉验证? http://www.cocoachina.com/articles/67515 

      简单说就是我们把训练集传递给GridSearchCV, 它会进一步将训练集分为训练集和测试集, 然后通过不断调整超参数, 进行交叉验证, 最后获得最优参数. 

      GridSearchCV会主动将数据分为训练集和测试集,这就是原因所在了.

      代码实现:

     1 from sklearn import datasets
     2 from sklearn.model_selection import train_test_split
     3 from sklearn.neighbors import KNeighborsClassifier
     4 from sklearn.metrics import accuracy_score
     5 from sklearn.model_selection import GridSearchCV
     6 
     7 
     8 # 1/获取数据
     9 digits = datasets.load_digits()
    10 X = digits.data
    11 y = digits.target
    12 
    13 # 2/分割数据,得到训练集和测试集
    14 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
    15 
    16 
    17 # 3/超参数配置
    18 param_grid = [
    19     {
    20         "weights":["uniform"],
    21         "n_neighbors":[i for i in range(1,11)]
    22     },
    23     {
    24         "weights":["distance"],
    25         "n_neighbors":[i for i in range(1,11)],
    26         "p":[i for i in range(1,6)]
    27     }
    28 ]
    29 
    30 
    31 # 4/为分类器和超参数搭建模型
    32 knn_clf = KNeighborsClassifier()
    33 grid_search = GridSearchCV(knn_clf, param_grid, n_jobs=-1, verbose=2)
    34 
    35 # 5/实例化模型(多种参数配置的分类器)fit训练集,
    36 # 本质上是将训练集进一步分为训练集和测试集,得到最好的参数配置
    37 # 因为要不断尝试各种参数交叉验证,所以非常耗时
    38 grid_search.fit(X_train, y_train)
    39 
    40 # 6/
    41 # 最终拿到最佳参数配置分类器 best_estimator_
    42 knn_clf = grid_search.best_estimator_
    43 
    44 # 7/使用最佳分类器对测试集预测
    45 y_predict = knn_clf.predict(X_test)
    46 
    47 # 8/得到准确率
    48 accuracy_score(y_test, y_predict))
  • 相关阅读:
    iOS使用自签名证书实现HTTPS请求
    DB操作-用批处理执行Sql语句
    SSL通信-忽略证书认证错误
    oracle 19c awr 丢失 i/o信息
    this.$route.query刷新后类型改变
    wx.navigateTo在app.js中偶发性失效
    微信小程序new Date()转换日期格式时iphonex为NaN
    下载cnpm成功,cnpm -v却不识别
    element-ui的表单验证如何清除校验提示语
    5. 最长回文子串(动态规划算法)
  • 原文地址:https://www.cnblogs.com/waterr/p/13394902.html
Copyright © 2011-2022 走看看