zoukankan      html  css  js  c++  java
  • 【学习笔记】模型的选择与调优

    交叉验证

    目的:为了让被评估的模型更加准确可信。

    交叉验证:将拿到的数据,分为训练和验证集。以下图为例:将数据分成5份,其中一份作为验证集。然后经过次(组)的测试,每次都更换不同的验证集。即得到5组模型的结果,取平均值作为最终结果。又称5折交叉验证。

    img

    超参数搜索-网格搜索

    通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,所以需要对模型预设几种超参数组合。每组超参数都采用交叉验证来进行评估。最后选出最优参数组合建立模型。

    img

    超参数搜索-网格搜索API

    sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None):对估计器的指定参数值进行详尽搜索

    参数:

    • estimator:估计器对象
    • param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}
    • cv:指定几折交叉验证

    方法:

    • fit:输入训练数据
    • score:准确率

    属性:

    • best_score_:在交叉验证中测试的最好结果
    • best_estimator_:最好的参数模型
    • cv_results_:每次交叉验证后的测试集准确率结果和训练集准确率结果

    【学习笔记】分类算法-k近邻算法中的“预测用户签到位置”改成网格搜索

    from sklearn.model_selection import GridSearchCV
    ...
    gc = GridSearchCV(knn, param_grid={"n_neighbors": [1, 3, 5, 10]}, cv=2)
    gc.fit(x_train, y_train.astype("int"))
    print("在测试集上的准确率:", gc.score(x_test, y_test.astype("int")))
    print("在交叉验证中最后的结果:", gc.best_params_)
    print("最好的模型是:", gc.best_estimator_)
    print("每个超参数每次的结果为:", gc.cv_results)
    

    结果:

    在测试集上的准确率: 0.8293838862559242
    在交叉验证中最后的结果: {'n_neighbors': 10}
    最好的模型是: KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
               metric_params=None, n_jobs=None, n_neighbors=10, p=2,
               weights='uniform')
    每个超参数每次的结果为: {'mean_fit_time': array([0.00898993, 0.00898921, 0.00849307, 0.01098037]), 'std_fit_time': array([6.79492950e-06, 7.74860382e-06, 9.17911530e-06, 1.51014328e-03]), 'mean_score_time': array([0.47162163, 0.62682521, 0.71092987, 0.84417915]), 'std_score_time': array([0.00648773, 0.00649297, 0.00772619, 0.00073266]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 10],
                 mask=[False, False, False, False],
           fill_value='?',
                dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 10}], 'split0_test_score': array([0.77042226, 0.82359905, 0.83149171, 0.8343528 ]), 'split1_test_score': array([0.773846  , 0.82554117, 0.83285559, 0.8342394 ]), 'mean_test_score': array([0.77213252, 0.8245692 , 0.83217301, 0.83429615]), 'std_test_score': array([1.71187149e-03, 9.71057298e-04, 6.81938147e-04, 5.67014065e-05]), 'rank_test_score': array([4, 3, 2, 1]), 'split0_train_score': array([1.        , 0.88316695, 0.86497974, 0.85133933]), 'split1_train_score': array([1.        , 0.88190608, 0.8636543 , 0.84490923]), 'mean_train_score': array([1.        , 0.88253651, 0.86431702, 0.84812428]), 'std_train_score': array([0.        , 0.00063043, 0.00066272, 0.00321505])}
    
    
  • 相关阅读:
    spring boot RESTfuldemo测试类
    再谈Redirect(客户端重定向)和Dispatch(服务器端重定向)
    HTTP协议解析
    HTTP协议详解(真的很经典)
    JMeter进行简单的数据库(mysql)压力测试
    LoadRunner利用ODBC编写MySql脚本
    性能瓶颈的分析
    bug的处理流程
    Loadrunner11 录制手机App脚本多种方法介绍
    利用fiddler录制脚本
  • 原文地址:https://www.cnblogs.com/zhangfengxian/p/10561147.html
Copyright © 2011-2022 走看看