zoukankan      html  css  js  c++  java
  • SVM的sklearn实现

    转载:豆-Metcalf

    1)SVM-LinearSVC.ipynb-线性分类SVM,iris数据集分类,正确率100%

     1 """
     2 功能:实现线性分类支持向量机
     3 说明:可以用于二类分类,也可以用于多类分类
     4 作者:唐天泽
     5 博客:http://write.blog.csdn.net/mdeditor#!postId=76188190
     6 日期:2017-08-09
     7 """
     8 
     9 # 导入本项目所需要的包
    10 import pandas as pd
    11 import numpy as np
    12 
    13 from sklearn import datasets
    14 
    15 from sklearn import svm
    16 
    17 # 使用留出法(hold-out),把数据集划分为训练集和测试集
    18 from sklearn.model_selection import train_test_split
    19 
    20 # 加载iris数据集
    def load_data():
        """Load the iris dataset and split it into training and test parts.

        Returns:
            A 4-tuple ``(X_train, X_test, y_train, y_test)``; ``test_size=0.10``
            is passed to ``train_test_split`` so 10% of the samples are held
            out for testing.
        """
        iris = datasets.load_iris()
        features, targets = iris.data, iris.target
        # To inspect the dataset shapes: features.shape, targets.shape

        # random_state=0 pins the shuffle seed used for the split.
        return tuple(train_test_split(
            features, targets, test_size=0.10, random_state=0))
    31 
    32 # 使用LinearSVC考察线性分类SVM的预测能力
    def test_LinearSVC(X_train, X_test, y_train, y_test):
        """Fit a linear SVM classifier and print its parameters and test score.

        Args:
            X_train: Training samples passed to ``fit``.
            X_test: Held-out samples passed to ``score``.
            y_train: Labels belonging to ``X_train``.
            y_test: Labels belonging to ``X_test``.
        """
        # Model selection: LinearSVC with its default hyper-parameters.
        classifier = svm.LinearSVC()

        # Train on the training partition only.
        classifier.fit(X_train, y_train)

        learned = (classifier.coef_, classifier.intercept_)
        print('Coefficients:%s, intercept %s' % learned)

        test_score = classifier.score(X_test, y_test)
        print('Score: %.2f' % test_score)
    43 
    if __name__ == "__main__":
        # Build the train/test partitions used for classification ...
        (X_train, X_test,
         y_train, y_test) = load_data()
        # ... then fit LinearSVC on them and print the results.
        test_LinearSVC(
            X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test)

    2) SVM-LinearSVC-kaggle.ipynb-线性分类SVM,手写数字数据集分类,正确率85%

     1 """
     2 功能:实现线性分类支持向量机
     3 说明:可以用于二类分类,也可以用于多类分类
     4 作者:唐天泽
     5 博客:http://write.blog.csdn.net/mdeditor#!postId=76188190
     6 日期:2017-08-09
     7 """
     8 
     9 # 导入本项目所需要的包
    10 import pandas as pd
    11 import numpy as np
    12 
    13 from sklearn import datasets
    14 
    15 from sklearn import svm
    16 
    17 # 使用留出法(hold-out),把数据集划分为训练集和测试集
    18 from sklearn.model_selection import train_test_split
    19 
    20 # The competition datafiles are in the directory ../input
    21 # 加载数据集
    def load_data(train_path="~/Desktop/knn/input/train.csv",
                  test_path="~/Desktop/knn/input/test.csv"):
        """Read the training and test CSV files of the digit data set.

        The file locations used to be hard-coded; they are now parameters whose
        defaults are the original paths, so ``load_data()`` behaves as before
        while other locations can be supplied explicitly.

        Args:
            train_path: CSV file whose first column holds the label and whose
                remaining columns hold the features.
            test_path: CSV file whose columns are all used as features.

        Returns:
            A tuple ``(label, train, testdata)``: the first column of the
            training file, its remaining columns, and the full contents of the
            test file, each taken from the DataFrame's ``.values``.
        """
        dataset = pd.read_csv(train_path).values
        label = dataset[:, 0]
        train = dataset[:, 1:]
        testdata = pd.read_csv(test_path).values
        return label, train, testdata
    28  # 使用LinearSVC考察线性分类SVM的预测能力
    def test_LinearSVC(label, train, testdata):
        """Train a linear SVM classifier and predict labels for the test data.

        Args:
            label: Target values passed to ``fit`` for the training samples.
            train: Training samples passed to ``fit``.
            testdata: Samples passed to ``predict``.

        Returns:
            The result of ``predict(testdata)`` on the fitted model.
        """
        # LinearSVC with default hyper-parameters; ``fit`` hands back the
        # fitted estimator, so training and prediction can be chained.
        fitted = svm.LinearSVC().fit(train, label)
        return fitted.predict(testdata)
    41 
    if __name__ == "__main__":
        label, train, testdata = load_data()
        result = test_LinearSVC(label, train, testdata)

        # One row per test sample; ImageId counts from 1.
        submission = pd.DataFrame({
            "ImageId": list(range(1, len(testdata) + 1)),
            "Label": result,
        })
        submission.to_csv(
            '~/Desktop/knn/output/Digit_recogniser_SVM_LinearSVC.csv',
            index=False,
            header=True,
        )

    补充:

     1 from sklearn import  svm
     2 
     3 from sklearn.datasets import  load_iris
     4 
     5 from sklearn.model_selection import train_test_split
     6 
     datas = load_iris()
     data_x, data_y = datas.data, datas.target

     # Hold out 30% of the samples for testing. No random_state is passed
     # here, so the split is not pinned to a seed.
     x_train, x_test, y_train, y_test = train_test_split(
         data_x, data_y, test_size=0.3)

     # SVC's default kernel is the Gaussian (RBF) kernel.
     clf = svm.SVC().fit(x_train, y_train)

     # Print the predictions, then the true labels, for a visual comparison.
     print(clf.predict(x_test))
     print(y_test)
  • 相关阅读:
    UVa 1349 (二分图最小权完美匹配) Optimal Bus Route Design
    UVa 1658 (拆点法 最小费用流) Admiral
    UVa 11082 (网络流建模) Matrix Decompressing
    UVa 753 (二分图最大匹配) A Plug for UNIX
    UVa 1451 (数形结合 单调栈) Average
    UVa 1471 (LIS变形) Defense Lines
    UVa 11572 (滑动窗口) Unique Snowflakes
    UVa 1606 (极角排序) Amphiphilic Carbon Molecules
    UVa 11054 Wine trading in Gergovia
    UVa 140 (枚举排列) Bandwidth
  • 原文地址:https://www.cnblogs.com/fuqia/p/9067463.html
Copyright © 2011-2022 走看看