zoukankan      html  css  js  c++  java
  • sklearn快速入门

    原创博文,转载请注明出处。
    (为了节约空间,打印结果常用“...”表示省略)

    一、加载数据集

    1. 加载sklearn自带的数据集

    scikit-learn有一些自带的标准数据集,例如用于分类的经典数据集iris和digits以及用于回归的boston house prices数据集。
    这些自带的数据集一种是类似字典的对象,它保存所有的数据(通常情况下,特征向量存储在.data成员中,在监督学习中,标签存储在.target成员中)和关于数据的元数据(如.target_names成员用来存储各个标签值对应的含义标签名称)。每个数据集中包含的成员不一定是一样的,既然数据集是一种类似字典的对象,那么,我们就可以通过“数据集名.keys()”来查看该数据集中,究竟有哪些成员,从而对数据集有个整体的把握。

    from sklearn import datasets
    iris = datasets.load_iris()
    print iris
    
    {'target_names': array(['setosa', 'versicolor', 'virginica'], 
          dtype='|S10'), 'data': array([[ 5.1,  3.5,  1.4,  0.2],
           [ 4.9,  3. ,  1.4,  0.2],
           [ 4.7,  3.2,  1.3,  0.2],
           ...
           [ 6.2,  3.4,  5.4,  2.3],
           [ 5.9,  3. ,  5.1,  1.8]]), 'target': array([0, 0, 0, 0, ···2, 2, 2]), ...}
    

    2. 访问自带数据集成员

    载入数据集后,可以通过“数据集名.成员名”的方式访问成员。

    访问特征集

    print iris.data
    
    [[ 5.1  3.5  1.4  0.2]
     [ 4.9  3.   1.4  0.2]
     [ 4.7  3.2  1.3  0.2]
     ...
     [ 6.2  3.4  5.4  2.3]
     [ 5.9  3.   5.1  1.8]]
    

    访问标签集

    print iris.target
    
    [0 0 ... 0 0 0 1 1 ... 1 1 2 2 ... 2 2]
    

    3. 加载数据非二维数组的数据集demo

    # _*_ coding:utf-8_*_
    from sklearn import datasets
    digits = datasets.load_digits()
    print digits.keys()
    print '------'
    # 第0个样本image为
    print digits.images[0]
    print '------'
    print digits.data[0]
    
    ['images', 'data', 'target_names', 'DESCR', 'target']
    ------
    [[  0.   0.   5.  13.   9.   1.   0.   0.]
     [  0.   0.  13.  15.  10.  15.   5.   0.]
     [  0.   3.  15.   2.   0.  11.   8.   0.]
     [  0.   4.  12.   0.   0.   8.   8.   0.]
     [  0.   5.   8.   0.   0.   9.   8.   0.]
     [  0.   4.  11.   0.   1.  12.   7.   0.]
     [  0.   2.  14.   5.  10.  12.   0.   0.]
     [  0.   0.   6.  13.  10.   0.   0.   0.]]
    ------
    [  0.   0.   5.  13.   9.   1.   0.   0.   0.   0.  13.  15.  10.  15.   5.
       0.   0.   3.  15.   2.   0.  11.   8.   0.   0.   4.  12.   0.   0.   8.
       8.   0.   0.   5.   8.   0.   0.   9.   8.   0.   0.   4.  11.   0.   1.
      12.   7.   0.   0.   2.  14.   5.  10.  12.   0.   0.   0.   0.   6.  13.
      10.   0.   0.   0.]
    

    可以看到.images和.data的区别:.data将.images中的元素由二维数组转为一维向量。

    4. 加载外部数据

    scikit-learn接受numpy array或者scipy稀疏矩阵这样的数值型数据,另外,它也和pandas里的DataFrame是兼容的。

    二、拟合分类器

    1. 分类器的定义

    在scikit-learn中,分类评估器是一种python类,具有fit(X, y)和predict(T)两个方法。
    过程是:用训练集fit某个分类器,然后用这个fit好的分类器实例来predict测试集

    2. demo:作用在digits数据集上的SVC(分类器)模型的建立

    SVC的全称为support vector classification,即:支持向量分类器,svm中定义的python类。

    2.1 基础原型

    digits数据集共有1797个样本,我们先简单地取第0个到第1759个作为训练集,第1760到第1796个作为测试集

    # _*_ coding:utf-8_*_
    from sklearn import datasets
    from sklearn import svm
    clf = svm.SVC()
    digits = datasets.load_digits()
    X, y = digits.data, digits.target
    clf.fit(X[0:1760], y[0:1760])    # 拟合SVC分类器实例
    
    # 对比第1760到1796个实例的预测值和真实值,查看效果
    print clf.predict(X[1760:1797])    # 用SVC分类器实例预测测试集
    print y[1760:1797]
    
    [1 7 3 3 3 3 1 3 3 5 3 3 3 6 3 3 5 3 3 3 2 3 2 3 5 7 9 3 4 3 3 4 9 3 3 3 3]
    [1 7 6 8 4 3 1 4 0 5 3 6 9 6 1 7 5 4 4 7 2 8 2 2 5 7 9 5 4 8 8 4 9 0 8 9 8]
    

    2.2 调参

    在new一个分类器实例的时候,可以对分类器的参数进行设置,例如:
    将上面的代码中,SVC分类器参数稍做设置

    clf = svm.SVC(gamma=0.001, C=100.)
    

    结果对比,预测准确度大大提高(注意第二行是真实值,第一行才是预测值)

    [1 7 6 8 4 3 1 4 0 5 3 6 9 6 1 7 5 4 4 7 2 8 2 2 5 7 9 5 4 8 8 4 9 0 8 9 8]
    [1 7 6 8 4 3 1 4 0 5 3 6 9 6 1 7 5 4 4 7 2 8 2 2 5 7 9 5 4 8 8 4 9 0 8 9 8]
    

    上面的调参过程是手动完成的,可以使用网格搜索和交叉验证的方法自动找到合适的参数。

    三、 模型存储

    我们建好模型之后,需要将模型存储起来,以供日后使用。
    python内置的pickle模块可以用来存储模型,但是,scikit-learn中还有类似功能的joblib类。它不可以像pickle一样将模型保存到字符串中,但它可以将模型保存在磁盘上,而且大数据应用场景下joblib更高效。
    下面我们将上面训练好的模型clf保存在SVC_clf.pkl文件中:
    将下面的代码加入上面代码的最后一行

    from sklearn.externals import joblib
    joblib.dump(clf, 'svc_clf.pkl')
    

    这样,以后需要用到模型只要重新joblib.load该模型即可,也可以在其他python进程中使用。
    例如,可以在同一个项目的另一个python文件中,重新导入该模型,进行预测:

    from sklearn import datasets
    from sklearn import svm
    digits = datasets.load_digits()
    X, y = digits.data, digits.target
    
    from sklearn.externals import joblib
    clf = joblib.load('svc_clf.pkl')
    print clf.predict(X[1760:1797])
    print y[1760:1797]
    
    [1 7 6 8 4 3 1 4 0 5 3 6 9 6 1 7 5 4 4 7 2 8 2 2 5 7 9 5 4 8 8 4 9 0 8 9 8]
    [1 7 6 8 4 3 1 4 0 5 3 6 9 6 1 7 5 4 4 7 2 8 2 2 5 7 9 5 4 8 8 4 9 0 8 9 8]
    

    运行结果和之前一样,说明是同一个模型。

    ps: pickle是腌黄瓜的意思,dump方法和load方法的命名很生动地表示出了:将模型倒入(dump)腌制容器中,以备日后取出(load)使(食)用。而joblib继承了pickle中两个方法的命名方法。

    四 约定

    1. 类型转换

    *所有输入的类型都会被转化为float64
    *回归预测输出转化为float64,分类预测输出不变

    2. 更改参数、重新拟合

    分类器可以通过.set_params()方法修改超参数,更改后再调用.fit()方法,将重写之前的模型。
    例如,将上面代码中,clf拟合后,加入以下代码:

    clf.set_params(C=0.01).fit(X,y)
    

    得到:

    [1 1 6 3 1 3 3 1 3 3 3 1 3 6 3 3 3 1 1 3 3 3 3 3 3 3 3 3 1 3 3 1 3 0 3 3 3]
    [1 7 6 8 4 3 1 4 0 5 3 6 9 6 1 7 5 4 4 7 2 8 2 2 5 7 9 5 4 8 8 4 9 0 8 9 8]
    

    预测结果效果下降很多,说明分类器超参数设置成功,新模型覆盖旧模型。

  • 相关阅读:
    跨平台编译ceres for Android
    解决OpenCV JavaCameraView相机preview方向问题
    OpenCV 4.0.1 找不到R.styleable解决
    mumu模拟器安装xposed--如何在android模拟器上进行root
    Windows编译OpenCV4Android解决undefined reference to std错误
    Skeleton with Assimp 骨骼动画解析
    Android GL deadlock timeout error
    Android device debug (adb) by Charge Only mode
    Firefox 多行标签的解决方案分享
    Linux 工程向 Windows 平台迁移的一些小小 tips
  • 原文地址:https://www.cnblogs.com/DianeSoHungry/p/8166800.html
Copyright © 2011-2022 走看看