zoukankan      html  css  js  c++  java
  • 星球 CMClass: 簡述 libsvm(Support Vector Machine library) 使用方法

    星球 - CMClass: 簡述 libsvm(Support Vector Machine library) 使用方法 | Ubuntu 正體中文站

    星球 - CMClass: 簡述 libsvm(Support Vector Machine library) 使用方法

    作者:何岳峰 hoamon | 來自:hoamon's sandbox | 2011/4/27 15:35:00
    libsvm乃台大林智仁老師開發的 Open source 工具,其目的為實作 Support Vector Machine 分類器,使用語言主要是 C++ ,目前也有 JAVA 版本,也提供其他語言的 wrapper ,像是 Perl, Python, Ruby, Matlab, Hashkell, Lisp 等。

    詳細數學就不介紹了,怕大家睡著(但其實是因為還沒看懂),各位可以看一下下面那段這個影片,大略了解 SVM 分類器如何區別不同資料。

    [!IFRAME FILTERED!]

    本文章主要介紹的是用 Python 語言去操作 libsvm 函式庫。

    先解壓縮 libsvm.tgz 檔,可以看到 python 及 windows 資料夾,如果要在 Linux 中使用的話,請在主目錄中作

    $ make lib

    這樣會得到 libsvm.so.2 檔,這是 libsvm 的主函式庫,而在 windows 中使用的話,它則是先幫你編譯好這個檔了,可在 windows/ 找到這個 libsvm.dll 檔。

    在 Linux 中,請把 python/*py 放到 /usr/local/lib/python2.6/site-packages 中,而 libsvm.so.2 放到 /usr/local/lib/python2.6/ 。

    在 windows 中,請把 python/*py 放到 C:\Python26\Lib\site-packages 中,而 libsvm.dll 請放到 C:\Python26\Lib\windows 資料夾中(因為 svmutil.py 寫死了它的相對路徑,所以務必依它的相對位置置放)。

    請在 Python shell 中,鍵入下列指令,測試是否安裝成功。

    >>> from svmutil import *
    >>>

    沒錯誤訊息,那就是安裝對了。

    使用 svm ,主要就是兩個動作: 訓練及預測。

    訓練:

    svmutil.svm_train 函式的引數有「類別標籤」、「觀察值」、「參數」。

    你的原始資料若是如下:

    1. 3, 4, 5, 6 => 第二類
    2. 3, 4, 5, 5 => 第一類
    3. ....

    前面的 #. 表第幾個觀察值,後面逗號分隔的數據為各維度的值,行末則是放置該觀察值為第幾類的說明。請把它轉成

    >>> Y = [2, 1, ...]
    >>> x = [(3, 4, 5, 6), (3, 4, 5, 5), ...]

    類別標籤請獨立放置到一個 list 中,而觀察值維度則依序放置到另一個 list 中。接下來,就能使用 svm_train:

    >>> from svmutil import *
    >>> model = svm_train(Y, x, '-c 4')

    所得到的 model 就是一個經過訓練的分類器。

    預測

    接下來,我們要拿訓練好的分類器去預測新的觀察值:

    >>> p_label, p_acc, p_val = svm_predict([0]*len(new_x), new_x, model)

    而 p_label 就是依 new_x 順序所對應的類別標籤 list 。

    下圖是我隨機生成的 300 點,圓點為原始的觀察值,而以線相連的連續點則是預測點。



    詳細程式碼請參照如下:

     1 #! /usr/bin/python
    2 # -*- coding: utf8 -*-
    3
    4 __author__="hoamon"
    5 __date__ =u"$2011/4/12 下午 05:52:31$"
    6
    7 from math import pi, sin, cos
    8 from random import random
    9 from matplotlib import pyplot as plt
    10 from svmutil import *
    11
    12 def circleData(centre, radius, down_limit_percent=0, lens=100, range=[0, 100]):
    13 points = []
    14 while len(points) < lens:
    15 _angle = 2 * pi * random()
    16 radius_percent = random()
    17 if radius_percent < down_limit_percent: continue
    18 _radius = radius * radius_percent
    19 x = centre[0] + cos(_angle) * _radius
    20 y = centre[1] + sin(_angle) * _radius
    21 if range[0] <= x <= range[1] and range[0] <= y <= range[1]:
    22 points.append((x, y))
    23 return points
    24
    25
    26 def test():
    27 u""" 製作三群的隨機資料,每群皆 100 個點,點位置的 x, y 限制在 0 ~ 100 之間
    28
    29 最後利用 matplotlib 繪製出來的圖,"單點"表原始資料,而連續點畫線的部份,
    30 該點位的類別則是利用 svm_predict 計算出來的。
    31
    32 Y = [1, 1, 1, ..., 2, 2, 2, ..., 3, 3, 3, ...]
    33 x = [(x1, y1), (x2, y2), ...]
    34 """
    35 Y = [1] * 100 + [2] * 100 + [3] * 100
    36 x1, x2, x3 = (circleData((35, 40), 12),
    37 circleData((35, 40), 48, down_limit_percent=0.25),
    38 circleData((80, 80), 20)
    39 )
    40 x = x1 + x2 + x3
    41
    42 m = svm_train(Y, x, '-c 4')
    43
    44 #INFO 在 100x100 的畫布上,打出 40000 個點,拿這 4 萬個點去給 m 作預測,算出這 4 萬個點的類別
    45 points = [(i*0.5, j*0.5) for j in xrange(0, 200) for i in xrange(0, 200)]
    46 p_label, p_acc, p_val = svm_predict([0]*40000, points, m)
    47
    48 line_1, line_2, line_3, pre_label = [], [], [], p_label[0]
    49 for i in xrange(0, 200):
    50 for j in xrange(0, 200):
    51 index = i * 200 + j
    52 now_label = p_label[index]
    53 if now_label == 1 :
    54 line_1.append(points[index])
    55 elif now_label == 2 :
    56 line_2.append(points[index])
    57 elif now_label == 3 :
    58 line_3.append(points[index])
    59
    60 fig = plt.figure()
    61 ax = fig.add_subplot(111)
    62 ax.plot([p[0] for p in x1], [p[1] for p in x1], 'ro')
    63 ax.plot([p[0] for p in x2], [p[1] for p in x2], 'go')
    64 ax.plot([p[0] for p in x3], [p[1] for p in x3], 'bo')
    65 ax.plot([p[0] for p in line_1], [p[1] for p in line_1], 'r-', alpha=0.5)
    66 ax.plot([p[0] for p in line_3], [p[1] for p in line_3], 'b-', alpha=0.5)
    67 ax.set_title('Points of three classes')
    68 ax.set_xlabel('x')
    69 ax.set_ylabel('y')
    70 ax.set_xlim(0, 100)
    71 ax.set_ylim(0, 100)
    72 plt.show()
    73 return m, p_label, p_acc, p_val
    74
    75
    76 if __name__ == "__main__":
    77 test()
  • 相关阅读:
    java修饰符 protect public protected
    java中interface使用
    java中super的用法
    引用的一道JAVA题目
    java中==和equals的区别(转)
    2019PHP面试题最全面归纳总结
    (一)PHP基础知识考察点
    Linux常用命令大全(非常全!!!)
    MAMP mysql无法启动 总结(以后有发现再添加)
    win 安装composer (详细教程)
  • 原文地址:https://www.cnblogs.com/lexus/p/2845146.html
Copyright © 2011-2022 走看看