zoukankan      html  css  js  c++  java
  • sklearn中的Pipeline

      在将sklearn中的模型持久化时,使用sklearn.pipeline.Pipeline(stepsmemory=None)各个步骤串联起来可以很方便地保存模型。

      例如,首先对数据进行了PCA降维,然后使用logistic regression进行分类,如果不使用pipeline,那么我们将分别保存两部分内容,一部分是PCA模型,一部分是logistic regression模型,稍微有点不方便。(当然,这么做也完全可以,使用Pipeline只是提供个方便罢了)

    1.Pipeline中的steps

      Pipeline的最后一步是一个“estimator”(sklearn中实现的各种机器学习算法实例,或者实现了estimator必须包含的方法的自定义类实例),之前的每一步都是“transformer”(必须实现fit和transform方法,比如MinMaxScaler、PCA、one-hot)。在Pipeline调用fit方法时,Pipeline中的每一步依次进行fit操作。

     1 import numpy as np
     2 
     3 from sklearn import linear_model, decomposition, datasets
     4 from sklearn.pipeline import Pipeline
     5 from sklearn.model_selection import GridSearchCV
     6 from sklearn.metrics import accuracy_score
     7 from sklearn.externals import joblib
     8 
     9 logistic = linear_model.LogisticRegression()
    10 
    11 pca = decomposition.PCA()
    12 pipe = Pipeline(steps=[('pca', pca), ('logistic', logistic)])
    13 
    14 digits = datasets.load_digits()
    15 X_digits = digits.data
    16 y_digits = digits.target
    17 
    18 # Parameters of pipelines can be set using ‘__’ separated parameter names:
    19 params = {
    20     'pca__n_components': [20, 40, 64],
    21     'logistic__C': np.logspace(-4, 4, 3),
    22 }
    23 estimator = GridSearchCV(pipe, params)
    24 estimator.fit(X_digits, y_digits)
    25 
    26 # When "estimator" predicts, actually "estimator.best_estimator_" is predicting.
    27 print(type(estimator.best_estimator_))
    28 
    29 y_pred = estimator.predict(X_digits)
    30 print(accuracy_score(y_true=y_digits, y_pred=y_pred))
    31 
    32 # Save model
    33 joblib.dump(estimator, 'models/pca_LR.pkl')

    2.Pipeline中的memory参数

      默认为None,当需要保存Pipeline中间的“transformer”时,才需要用到memory参数。

    3.参考文献

      Pipelining: chaining a PCA and a logistic regression

      

  • 相关阅读:
    php启用zlib压缩文件
    理解MySQL——架构与概念
    二级域名session 共享方案
    SessionID的本质
    PHP核心技术笔记(1):面向对象的核心概念
    改掉这些坏习惯,让你从php菜鸟变php高手
    理解MySQL——索引与优化
    [转]步步教你如何修改OS/400缺省的登陆画面
    [转]Delphi中的线程类
    [转]MSSQL重复记录处理
  • 原文地址:https://www.cnblogs.com/wuliytTaotao/p/9329695.html
Copyright © 2011-2022 走看看