  • [Machine Learning] Support Vector Machines (SVM)

    Reference blog post: "To Learn SVM, This One Article Is Enough! (with detailed code)"

    https://www.jiqizhixin.com/articles/2018-10-17-20

    WeChat public-account article: [Plain-Language Machine Learning] Algorithm Theory + Practice: Support Vector Machines (SVM)

    https://mp.weixin.qq.com/s?__biz=MzIwODI2NDkxNQ==&mid=2247487899&idx=5&sn=9ebc893c29831f0b800554e6f613c03d&chksm=97049a27a0731331115fee11d4f27d45d5e34e2da0785eece300514b4ee5af248631c67dfe55&scene=21#wechat_redirect

    Bilibili: Machine Learning, Whiteboard Derivation Series (6): Support Vector Machine (SVM)

    https://www.bilibili.com/video/av28186618?from=search&seid=3278332437610810809

    The theory section will be filled in later.
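    Until that section is written, for reference, the standard soft-margin primal problem that scikit-learn's `SVC` solves can be stated as (C is the penalty on margin violations):

```latex
\min_{w,\,b,\,\xi}\quad \frac{1}{2}\lVert w \rVert^2 + C \sum_{i=1}^{n} \xi_i
\qquad \text{s.t.}\quad y_i \left( w^\top x_i + b \right) \ge 1 - \xi_i,
\quad \xi_i \ge 0,\quad i = 1, \dots, n
```

    A larger C tolerates fewer margin violations (risking overfitting); a smaller C allows a wider, softer margin.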

    Dataset download:

    Link: https://pan.baidu.com/s/1SE1K3mw7Xam_zi-1i9Axug
    Extraction code: xa82
    The dataset is the breast cancer diagnostic dataset from Wisconsin, USA. It can also be found in the UCI repository: http://mlr.cs.umass.edu/ml/machine-learning-databases/breast-cancer-wisconsin/
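    If downloading from Baidu Pan is inconvenient, the same 569-sample Wisconsin diagnostic data also ships with scikit-learn. A minimal sketch (note that sklearn encodes the target the opposite way from the code below: 0 = malignant, 1 = benign):

```python
from sklearn.datasets import load_breast_cancer
import pandas as pd

# load_breast_cancer bundles the Wisconsin diagnostic data with sklearn,
# so no CSV download is needed.
cancer = load_breast_cancer()
df = pd.DataFrame(cancer.data, columns=cancer.feature_names)
df['target'] = cancer.target  # 0 = malignant, 1 = benign

print(df.shape)  # (569, 31): 30 features plus the target column
```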

    Below is a reproduction of the code from [Plain-Language Machine Learning] Algorithm Theory + Practice: Support Vector Machines (SVM):

import os
from sklearn import svm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

# Load the dataset
data = pd.read_csv("data/wdbc.data.csv")
# Data exploration
# The dataset has many columns, so make the DataFrame display all of them
pd.set_option('display.max_columns', None)
#print('data.columns', data.columns)
#print('data.head(5)', data.head(5))
#print('data.describe()', data.describe())
'''
data.columns Index(['id', 'diagnosis', 'radius_mean', 'texture_mean', 'perimeter_mean',
       'area_mean', 'smoothness_mean', 'compactness_mean', 'concavity_mean',
       'concave points_mean', 'symmetry_mean', 'fractal_dimension_mean',
       'radius_se', 'texture_se', 'perimeter_se', 'area_se', 'smoothness_se',
       'compactness_se', 'concavity_se', 'concave points_se', 'symmetry_se',
       'fractal_dimension_se', 'radius_worst', 'texture_worst',
       'perimeter_worst', 'area_worst', 'smoothness_worst',
       'compactness_worst', 'concavity_worst', 'concave points_worst',
       'symmetry_worst', 'fractal_dimension_worst'],
      dtype='object')

data.head(5)
         id diagnosis  radius_mean  texture_mean  perimeter_mean  area_mean   ......
0    842302         M        17.99         10.38          122.80     1001.0
1    842517         M        20.57         17.77          132.90     1326.0
2  84300903         M        19.69         21.25          130.00     1203.0
3  84348301         M        11.42         20.38           77.58      386.1
4  84358402         M        20.29         14.34          135.10     1297.0

data.describe()
                 id  radius_mean  texture_mean  perimeter_mean    area_mean   ......
count  5.690000e+02   569.000000    569.000000      569.000000   569.000000
mean   3.037183e+07    14.127292     19.289649       91.969033   654.889104
std    1.250206e+08     3.524049      4.301036       24.298981   351.914129
min    8.670000e+03     6.981000      9.710000       43.790000   143.500000
25%    8.692180e+05    11.700000     16.170000       75.170000   420.300000
50%    9.060240e+05    13.370000     18.840000       86.240000   551.100000
75%    8.813129e+06    15.780000     21.800000      104.100000   782.700000
max    9.113205e+08    28.110000     39.280000      188.500000  2501.000000

'''
# Split the feature columns into 3 groups
features_mean = list(data.columns[2:12])
features_se = list(data.columns[12:22])
features_worst = list(data.columns[22:32])
# Data cleaning
# The ID column carries no information; drop it
data.drop("id", axis=1, inplace=True)
# Map B (benign) to 0 and M (malignant) to 1
data['diagnosis'] = data['diagnosis'].map({'M': 1, 'B': 0})

# Make sure the output directory for the figures exists
os.makedirs('result', exist_ok=True)

# Visualize the diagnosis counts
sns.countplot(x='diagnosis', data=data)
plt.savefig('./result/diagnosis.jpg')
plt.show()

# Heatmap of the correlations between the features_mean columns
corr = data[features_mean].corr()
plt.figure(figsize=(14, 14))
# annot=True prints the value inside each cell
sns.heatmap(corr, annot=True)
plt.savefig('./result/heatmap.jpg')
plt.show()

# Feature selection
features_remain = ['radius_mean', 'texture_mean', 'smoothness_mean',
                   'compactness_mean', 'symmetry_mean', 'fractal_dimension_mean']
# Hold out 30% of the data as the test set; the rest is the training set
train, test = train_test_split(data, test_size=0.3)
# Extract the selected features as training and test data
train_X = train[features_remain]
train_y = train['diagnosis']
test_X = test[features_remain]
test_y = test['diagnosis']
# Z-score standardization: give every feature mean 0 and variance 1
ss = StandardScaler()
train_X = ss.fit_transform(train_X)
test_X = ss.transform(test_X)

# Create the SVM classifier
model = svm.SVC()
# Fit on the training set
model.fit(train_X, train_y)
# Predict on the test set
prediction = model.predict(test_X)
# accuracy_score expects (y_true, y_pred); the exact value varies with the random split
print('Accuracy:', accuracy_score(test_y, prediction))  # e.g. Accuracy: 0.9415204678362573
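The reproduced script trains `SVC()` with its default hyperparameters. As a sketch of a possible next step (not part of the original post), C, gamma, and the kernel can be tuned with `GridSearchCV`, wrapping the `StandardScaler` in a `Pipeline` so scaling is re-fit inside each cross-validation fold and never leaks test-fold statistics:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Same Wisconsin diagnostic data, loaded from sklearn for a self-contained example
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42)

# The pipeline names its steps 'standardscaler' and 'svc' automatically,
# so grid keys are prefixed with 'svc__'.
pipe = make_pipeline(StandardScaler(), SVC())
param_grid = {
    'svc__C': [0.1, 1, 10],
    'svc__gamma': ['scale', 0.01, 0.1],
    'svc__kernel': ['rbf', 'linear'],
}
search = GridSearchCV(pipe, param_grid, cv=5)
search.fit(X_train, y_train)

print('best params:', search.best_params_)
print('test accuracy:', search.score(X_test, y_test))
```

The candidate grid above is an illustrative choice, not the original author's; widen or narrow it depending on how much compute is available.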

     

  • Original post: https://www.cnblogs.com/DJames23/p/12539651.html