zoukankan      html  css  js  c++  java
  • KNN算法网格搜索最优参数

    主要用到 sklearn.model_selection包下的GridSearchCV类。

    总共分为几步:

             a.创建训练集和测试集

          b.创建最优参数字典

          c.构建GridSearchCV对象

          d.进行数据训练

          e.得出最优超参数

    a.创建训练集和测试集

    import numpy as np
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from sklearn.model_selection import GridSearchCV
    
    digits = datasets.load_digits()
    
    x = digits.data
    y = digits.target
    
    x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=666)
    

     b.创建最优参数字典

    param_digits =[
        {
            'weights':['uniform'],
            'n_neighbors':[i for i in range(1,11)]
        },
        {
            'weights':['distance'],
            'n_neighbors':[i for i in range(1,6)],
            'p':[i for i in range(1,6)]
        }
    ]
    

     c.构建GridSearchCV对象

    knn_grid = KNeighborsClassifier()
    
    grid_search = GridSearchCV(knn_grid,param_digits,n_jobs=-1,verbose=2)#n_job指的是所用的核数,也就是多线程执行,当等于-1时,也就是等于你的计算机的核数,verbose越大,打印的信息越详细
    

     d.进行数据训练

    grid_search.fit(x_train,y_train)
    

    e.得出最优超参数 

    param = grid_search.best_params_
    
    print(param)
    
  • 相关阅读:
    BZOJ3196: Tyvj 1730 二逼平衡树
    (转载)你真的会二分查找吗?
    Codeforces Round #259 (Div. 2)
    BZOJ1452: [JSOI2009]Count
    BZOJ2733: [HNOI2012]永无乡
    BZOJ1103: [POI2007]大都市meg
    BZOJ2761: [JLOI2011]不重复数字
    BZOJ1305: [CQOI2009]dance跳舞
    挖坑#4-----倍增
    BZOJ1042: [HAOI2008]硬币购物
  • 原文地址:https://www.cnblogs.com/lyr999736/p/10665572.html
Copyright © 2011-2022 走看看