zoukankan      html  css  js  c++  java
  • SVM核函数功能和选择——可视化 附源代码

     1 # coding: utf-8
     2 
     3 # In[3]:
     4 
     5 
     6 import numpy as np
     7 import matplotlib.pyplot as plt
     8 from matplotlib.colors import ListedColormap 
     9 from sklearn import svm
    10 from sklearn.datasets import make_circles, make_moons, make_blobs,make_classification
    11 
    12 
    13 # In[4]:
    14 
    15 
    16 n_samples = 100
    17 
    18 datasets = [
    19     make_moons(n_samples=n_samples, noise=0.2, random_state=0), 
    20     make_circles(n_samples=n_samples, noise=0.2, factor=0.5, random_state=1), 
    21     make_blobs(n_samples=n_samples, centers=2, random_state=5), make_classification(n_samples=n_samples,n_features =
    22     2,n_informative=2,n_redundant=0, random_state=5)
    23     ]
    24 
    25 Kernel = ["linear","poly","rbf","sigmoid"] #四个数据集分别是什么样子呢?
    26 
    27 for X,Y in datasets: 
    28     plt.figure(figsize=(5,4))
    29     plt.scatter(X[:,0],X[:,1],c=Y,s=50,cmap="rainbow")
    30 
    31 
    32 # In[14]:
    33 
    34 
    35 nrows=len(datasets) 
    36 ncols=len(Kernel) + 1
    37 
    38 fig, axes = plt.subplots(nrows, ncols,figsize=(20,16))
    39 
    40 for X,Y in datasets:
    41     plt.figure(figsize=(5,4))
    42     plt.scatter(X[:,0],X[:,1],c=Y,s=50,cmap="rainbow")
    43 
    44 #第一层循环:在不同的数据集中循环
    45 for ds_cnt, (X,Y) in enumerate(datasets):
    46 
    47     #在图像中的第一列,放置原数据的分布
    48     ax = axes[ds_cnt, 0]
    49     if ds_cnt == 0: 
    50         ax.set_title("Input data")
    51     ax.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.cm.Paired,edgecolors='k') 
    52     ax.set_xticks(())
    53     ax.set_yticks(())
    54 
    55     #第二层循环:在不同的核函数中循环
    56     #从图像的第二列开始,一个个填充分类结果
    57     for est_idx, kernel in enumerate(Kernel):
    58 
    59         #定义子图位置
    60         ax = axes[ds_cnt, est_idx + 1]
    61 
    62         #建模
    63         clf = svm.SVC(kernel=kernel, gamma=2).fit(X, Y)
    64         score = clf.score(X, Y)
    65 
    66         #绘制图像本身分布的散点图
    67         ax.scatter(X[:, 0], X[:, 1], c=Y
    68         ,zorder=10
    69         ,cmap=plt.cm.Paired,edgecolors='k')
    70         #绘制支持向量
    71         ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=50, 
                  facecolors='none', zorder=10, edgecolors='k') 72 73 #绘制决策边界 74 x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 75 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 76 77 #np.mgrid,合并了我们之前使用的np.linspace和np.meshgrid的用法 #一次性使用最大值和最小值来生成网格 78 #表示为[起始值:结束值:步长] 79 #如果步长是复数,则其整数部分就是起始值和结束值之间创建的点的数量,并且结束值被包含在内 80 XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j] #np.c_,类似于np.vstack的功能 81 Z = clf.decision_function(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape) #填充等高线不同区域的颜色 82 ax.pcolormesh(XX, YY, Z > 0, cmap=plt.cm.Paired) #绘制等高线 83 ax.contour(XX, YY, Z, colors=['k', 'k', 'k'], linestyles=['--', '-', '--'], levels=[-1, 0, 1]) 84 85 #设定坐标轴为不显示ax.set_xticks(()) ax.set_yticks(()) 86 87 #将标题放在第一行的顶上if ds_cnt == 0: 88 ax.set_title(kernel) 89 90 #为每张图添加分类的分数 91 ax.text(0.95, 0.06, ('%.2f' % score).lstrip('0') 92 , size=15 93 , bbox=dict(boxstyle='round', alpha=0.8, facecolor='white') #为分数添加一个白色的格子作为底色 94 , transform=ax.transAxes #确定文字所对应的坐标轴,就是ax子图的坐标轴本身 95 , horizontalalignment='right' #位于坐标轴的什么方向 96 ) 97 98 plt.tight_layout() 99 plt.show()

       

  • 相关阅读:
    3D标签云
    IntelliJ IDEA 13.1.1版本偶然的错误
    414. Third Maximum Number
    217. Contains Duplicate
    442.Find All Duplicates in an Array
    3D轮播图
    448. Find All Numbers Disappeared in an Array
    Beautifulsoup模块
    MySQL数据库
    常用模块
  • 原文地址:https://www.cnblogs.com/ku1274755259/p/11140754.html
Copyright © 2011-2022 走看看