zoukankan      html  css  js  c++  java
  • 【364】SVM 通过 sklearn 可视化实现

    先看下效果图:

    # 先调入需要的模块
    
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import svm
    import seaborn as sb
    
    # 生成几个数据点
    
    data = np.array([
        [0.1, 0.7],
        [0.3, 0.6],
        [0.4, 0.1],
        [0.5, 0.4],
        [0.8, 0.04],
        [0.42, 0.6],
        [0.9, 0.4],
        [0.6, 0.5],
        [0.7, 0.2],
        [0.7, 0.67],
        [0.27,0.8],
        [0.5, 0.72]
        ])
        
    
    target = [1] * 6 + [0] * 6
    
    x_line = np.linspace(0, 1, 100)
    y_line = 1 - x_line
    plt.scatter(data[:6, 0], data[:6, 1], marker='o', s=100, lw=3)
    plt.scatter(data[6:, 0], data[6:, 1], marker='x', s=100, lw=3)
    plt.plot(x_line, y_line)
    
    # 定义计算域、文字说明等
    
    C = 0.0001  # SVM regularization parameter, since Scikit-learn doesn't allow C=0
    # linear_svc = svm.SVC(kernel='linear', C=C).fit(data, target)
    
    # create a mesh to plot in
    h = 0.002
    x_min, x_max = data[:, 0].min() - 0.2, data[:, 0].max() + 0.2
    y_min, y_max = data[:, 1].min() - 0.2, data[:, 1].max() + 0.2
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    
    # title for the plots
    titles = ['SVC with linear kernel',
              'SVC with RBF kernel',
              'SVC with polynomial (degree 3) kernel']
    
    # RBF Kernel
    
    plt.figure(figsize=(16, 15))
    
    for i, gamma in enumerate([1, 5, 15, 35, 45, 55]):
        rbf_svc = svm.SVC(kernel='rbf', gamma=gamma, C=C).fit(data, target)
        
        # ravel - flatten
        # c_ - vstack
        # #把后面两个压扁之后变成了x1和x2,然后进行判断,得到结果在压缩成一个矩形
        Z = rbf_svc.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)
        
        plt.subplot(3, 2, i + 1)
        plt.subplots_adjust(wspace=0.4, hspace=0.4)
        plt.contourf(xx, yy, Z, cmap=plt.cm.ocean, alpha=0.6)
    
        # Plot the training points
        plt.scatter(data[:6, 0], data[:6, 1], marker='o', color='r', s=100, lw=3)
        plt.scatter(data[6:, 0], data[6:, 1], marker='x', color='k', s=100, lw=3)
        
        plt.title('RBF SVM with $gamma=$' + str(gamma))
        
    plt.show()
    
  • 相关阅读:
    Java-16 常用包装类
    Java-15 String类
    flask实现web端微信登录
    Flask-excel导出数据
    Java-14 单例模式
    Java-13 接口
    Java-12 抽象类
    结对项目:一个自动生成小学四则运算题目的命令行程序(c++)
    软工作业:第一次个人项目作业
    操作系统笔记一:操作系统引论
  • 原文地址:https://www.cnblogs.com/alex-bn-lee/p/10334717.html
Copyright © 2011-2022 走看看