zoukankan      html  css  js  c++  java
  • Python超参数自动搜索模块GridSearchCV上手

    1. 引言

    当我们跑机器学习程序时,尤其是调节网络参数时,通常待调节的参数有很多,参数之间的组合更是繁复。依照注意力>时间>金钱的原则,人力手动调节注意力成本太高,非常不值得。For循环或类似于for循环的方法受限于太过分明的层次,不够简洁与灵活,注意力成本高,易出错。本文介绍sklearn模块的GridSearchCV模块,能够在指定的范围内自动搜索具有不同超参数的不同模型组合,有效解放注意力。

    2. GridSearchCV模块简介

    这个模块是sklearn模块的子模块,导入方法非常简单

    from sklearn.model_selection import GridSearchCV

    函数原型:

    class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True)

    其中cv可以是整数或者交叉验证生成器或一个可迭代器,cv参数对应的4种输入列举如下:

    1. None:默认参数,函数会使用默认的3折交叉验证
    2. 整数k:k折交叉验证。对于分类任务,使用StratifiedKFold(类别平衡,每类的训练集占比一样多,具体可以查看官方文档)。对于其他任务,使用KFold
    3. 交叉验证生成器:得自己写生成器,头疼,略
    4. 可以生成训练集与测试集的迭代器:同上,略

    3. 分析结果自动保存

    逗号分隔值(Comma-Separated Values,CSV,有时也称为字符分隔值,因为分隔字符也可以不是逗号),其文件以纯文本形式存储表格数据(数字和文本)。纯文本意味着该文件是一个,不含必须像二进制数字那样被解读的数据。CSV文件由任意数目的记录组成,记录间以某种换行符分隔;每条记录由字段组成,字段间的分隔符是其它字符或字符串,最常见的是逗号或制表符。通常,所有记录都有完全相同的字段序列。

    CSV文件有个突出的优点,可以用excel等软件打开,比起记事本和matlab、python等编程语言界面,便于查看、制作报告、后期整理等。

    GridSearchCV模块中,不同超参数的组合方式及其计算结果以字典的形式保存在 clf.cv_results_中,python的pandas模块提供了高效整理数据的方法,只需要3行代码即可解决问题。

    cv_result = pd.DataFrame.from_dict(clf.cv_results_)
    with open('cv_result.csv','w') as f:
      cv_result.to_csv(f)

    4. 完整例程

    代码清晰易懂,无须解释。https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search

     1 import pandas as pd
     2 from sklearn import svm, datasets
     3 from sklearn.model_selection import GridSearchCV
     4 from sklearn.metrics import classification_report
     5 
     6 iris = datasets.load_iris()
     7 parameters = {'kernel':('linear', 'rbf'), 'C':[1, 2, 4], 'gamma':[0.125, 0.25, 0.5 ,1, 2, 4]}
     8 svr = svm.SVC()
     9 clf = GridSearchCV(svr, parameters, n_jobs=-1)
    10 clf.fit(iris.data, iris.target)
    11 cv_result = pd.DataFrame.from_dict(clf.cv_results_)
    12 with open('cv_result.csv','w') as f:
    13     cv_result.to_csv(f)
    14     
    15 print('The parameters of the best model are: ')
    16 print(clf.best_params_)
    17 
    18 y_pred = clf.predict(iris.data)
    19 print(classification_report(y_true=iris.target, y_pred=y_pred))

    5. 相关资料

    1. sklearn.model_selection.GridSearchCV模块主页: http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
    2. pandas.DataFrame模块主页:http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html
    3. 本文例程 https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search

    6.未来展望

       当前的工作局限于算法超参数搜索,还没有结合预处理方式自动搜索、不同算法之间自动搜索、不同深度学习模型自动搜索等。如何利用pipeline、keras、tf等模块,实现整个环节的自动搜索,是下一步学习与总结的方向。

  • 相关阅读:
    《数据密集型应用系统设计》读书笔记
    每周总结
    每周总结
    每周总结
    《数据密集型应用系统设计》读书笔记
    每周总结
    《重构》读书笔记
    每周总结
    软件过程与管理知识回顾
    操作系统知识汇总5-6章
  • 原文地址:https://www.cnblogs.com/nwpuxuezha/p/6618205.html
Copyright © 2011-2022 走看看