zoukankan      html  css  js  c++  java
  • 对 load_breast_cancer 进行 SVM 分类

    原创转载请注明出处:https://www.cnblogs.com/agilestyle/p/12786022.html

    SVC 的构造函数

    这里有三个重要的参数 kernel、C 和 gamma

    kernel

    kernel 代表核函数的选择,它有四种选择,只不过默认是 rbf,即高斯核函数。

    • linear:线性核函数
    • poly:多项式核函数
    • rbf:高斯核函数
    • sigmoid:sigmoid 核函数

    这四种函数代表不同的映射方式,在实际工作中,如何选择这 4 种核函数?

    • 线性核函数是在数据线性可分的情况下使用的,运算速度快,效果好。不足在于它不能处理线性不可分的数据。
    • 多项式核函数可以将数据从低维空间映射到高维空间,但参数比较多,计算量大。
    • 高斯核函数同样可以将样本映射到高维空间,但相比于多项式核函数来说所需的参数比较少,通常性能不错,所以是默认使用的核函数。
    • sigmoid 经常用在神经网络的映射中。因此当选用 sigmoid 核函数时,SVM 实现的是多层神经网络。

    4 种核函数,除了第一种线性核函数外,其余 3 种都可以处理线性不可分的数据。

    C

    参数 C 代表目标函数的惩罚系数,惩罚系数指的是分错样本时的惩罚程度,默认情况下为 1.0。当 C 越大的时候,分类器的准确性越高,但同样容错率会越低,泛化能力会变差。相反,C 越小,泛化能力越强,但是准确性会降低。

    gamma

    参数 gamma 代表核函数的系数。

    准备数据

    import numpy as np
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import train_test_split
    from sklearn.svm import SVC
    
    cancer = load_breast_cancer()
    # sklearn.utils.Bunch
    print(type(cancer))
    # dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
    print(cancer.keys())
    
    # print(cancer.DESCR)
    
    features = cancer.data
    labels = cancer.target
    # (569, 30)
    print(features.shape)
    # (569,)
    print(labels.shape)

    分割数据

    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2)

    建模训练、评价模型

    线性核

    model_linear = SVC(C=1.0, kernel='linear')  # 线性核
    model_linear.fit(X_train, y_train)
    
    train_score = model_linear.score(X_train, y_train)
    test_score = model_linear.score(X_test, y_test)
    
    print('train_score:{0}; test_score:{1}'.format(train_score, test_score))

    高斯核

    model_rbf = SVC(C=1.0, kernel='rbf', gamma=0.1)  # 高斯核
    model_rbf.fit(X_train, y_train)
    
    train_score = model_rbf.score(X_train, y_train)
    test_score = model_rbf.score(X_test, y_test)
    
    print('train_score:{0}; test_score:{1}'.format(train_score, test_score))
    
    gammas = np.linspace(0, 0.005, 30)
    param_grid = {
        'gamma': gammas
    }
    
    model = GridSearchCV(SVC(), param_grid, cv=10)  # 使用网格搜索和交叉验证找出最佳gamma值
    model.fit(X_train, y_train)
    
    print('best param:{0}; best score:{1}'.format(model.best_params_, model.best_score_))

    多项式核

    model_poly = SVC(kernel='poly', degree=2, gamma='scale')  # 多项式核 2阶
    model_poly.fit(X_train, y_train)
    
    train_score = model_poly.score(X_train, y_train)
    test_score = model_poly.score(X_test, y_test)
    
    print('train_score:{0}; test_score:{1}'.format(train_score, test_score))
    
    degrees = np.linspace(1, 10, 10)
    param_grid = {
        'degree': degrees
    }
    
    model = GridSearchCV(SVC(kernel='poly', gamma='scale'), param_grid, cv=10)
    model.fit(X_train, y_train)
    
    print('best param:{0}; best score:{1}'.format(model.best_params_, model.best_score_))

    Note:输出的最佳参数可能会不一样是因为交叉验证时数据集的划分每次都不一样。可以选择得分最高的,也可以执行多次选择取出现次数最多的那个。

    Reference

    https://time.geekbang.org/column/article/80712

    https://scikit-learn.org/stable/modules/svm.html

    https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html

    https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

  • 相关阅读:
    rabbitmq channel参数详解
    java中反射知识点总结
    SpringBoot的ApplicationRunner
    ServletContextInitializer添加 servlet filter listener
    如何在Job中获取 IOC applicationcontext
    QRCode.js:使用 JavaScript 生成微信二维码
    SpringBoot整合Quartz
    java.lang.ClassNotFoundException: org.springframework.web.servlet.DispatcherServlet解决
    javascript简介
    mysql索引
  • 原文地址:https://www.cnblogs.com/agilestyle/p/12786022.html
Copyright © 2011-2022 走看看