LIBSVM是支持向量机(Support Vector Machine, SVM)的一种开源实现,采用了SMO算法。源代码编写有独到之处,值得一读。
常用结构
svm_node结构
定义了构成输入特征向量的元素,index为特征索引,value为该特征的值。注意:C版本用index = -1的元素标记向量结束;Java版本以数组长度确定向量的结尾,不需要结束标记,同一向量内的index应按升序排列。
/**
 * One (index, value) entry of a sparse feature vector.
 *
 * <p>An input vector is represented as an array of these entries, in the
 * manner of a sparse-matrix representation.
 */
public class svm_node implements java.io.Serializable {
    /** Index of the feature this entry describes. */
    public int index;

    /** Value of the feature at {@link #index}. */
    public double value;
}
借鉴了稀疏矩阵的表示方法。对于一个输入向量,定义为svm_node构成的一维数组
svm_node[] pa = {pa0, pa1};
所有输入向量由一个二维数组表示
svm_node[][] datas = {pa, pb};
标记序列
就是一个double数组,第i个元素是输入序列datas中第i个输入向量的类别标记。
double[] labels = {1.0, -1.0};
svm_problem结构
定义了(X, Y)的训练样本结构
/**
 * A training set of (X, Y) samples.
 */
public class svm_problem implements java.io.Serializable {
    /** Number of training samples. */
    public int l;

    /** Labels, one per training sample. */
    public double[] y;

    /** Input vectors, one {@code svm_node[]} per training sample. */
    public svm_node[][] x;
}
其中l是样本数量。
svm_parameter结构
定义了训练时的重要参数
/**
 * Training parameters: the SVM formulation, the kernel and its settings,
 * and the options that are only used while training.
 */
public class svm_parameter implements Cloneable, java.io.Serializable {
    /* svm_type */
    public static final int C_SVC = 0;
    public static final int NU_SVC = 1;
    public static final int ONE_CLASS = 2;
    public static final int EPSILON_SVR = 3;
    public static final int NU_SVR = 4;

    /* kernel_type */
    public static final int LINEAR = 0;
    public static final int POLY = 1;
    public static final int RBF = 2;
    public static final int SIGMOID = 3;
    public static final int PRECOMPUTED = 4;

    /** One of the {@code svm_type} constants above. */
    public int svm_type;
    /** One of the {@code kernel_type} constants above. */
    public int kernel_type;
    public int degree;          // for poly
    public double gamma;        // for poly/rbf/sigmoid
    public double coef0;        // for poly/sigmoid

    // these are for training only
    public double cache_size;   // in MB
    public double eps;          // stopping criteria
    public double C;            // for C_SVC, EPSILON_SVR and NU_SVR
    public int nr_weight;       // for C_SVC
    public int[] weight_label;  // for C_SVC
    public double[] weight;     // for C_SVC
    public double nu;           // for NU_SVC, ONE_CLASS, and NU_SVR
    public double p;            // for EPSILON_SVR
    public int shrinking;       // use the shrinking heuristics
    public int probability;     // do probability estimates

    /**
     * Returns a field-by-field copy of this parameter set.
     *
     * <p>The copy is shallow: the {@code weight_label} and {@code weight}
     * arrays are shared with the original instance.
     *
     * @return the copy, or {@code null} if cloning is not supported
     */
    @Override
    public Object clone() {
        try {
            return super.clone();
        } catch (CloneNotSupportedException e) {
            // Not expected in practice: this class implements Cloneable.
            return null;
        }
    }
}
主要分为两大类参数:一类描述分类器的核函数(kernel_type、degree、gamma、coef0),另一类是训练算法SMO的参数(如停止精度eps、惩罚系数C、缓存大小cache_size等)。
训练
通过调用svm.svm_train()训练模型
public static svm_model svm_train(svm_problem prob, svm_parameter param)
返回svm_model类对象表示训练得到的分类器
预测
通过svm.svm_predict()利用分类器进行预测
public static double svm_predict(svm_model model, svm_node[] x)
返回类别标记
实例代码如下,输入点pa = (10.0, 10.0),ya = 1.0;pb = (-10.0, -10.0),yb = -1.0
测试点 (-0.1, 0)
1 import libsvm.svm; 2 import libsvm.svm_model; 3 import libsvm.svm_node; 4 import libsvm.svm_parameter; 5 import libsvm.svm_problem; 6 7 public class SvmTest { 8 public static void main(String[] args) { 9 10 svm_node pa0 = new svm_node(); 11 pa0.index = 0; 12 pa0.value = 10.0; 13 14 svm_node pa1 = new svm_node(); 15 pa1.index = -1; 16 pa1.value = 10.0; 17 18 svm_node pb0 = new svm_node(); 19 pb0.index = 0; 20 pb0.value = -10.0; 21 22 svm_node pb1 = new svm_node(); 23 pb1.index = -1; 24 pb1.value = -10.0; 25 26 svm_node[] pa = {pa0, pa1}; 27 svm_node[] pb = {pb0, pb1}; 28 29 svm_node[][] datas = {pa, pb}; 30 31 double[] labels = {1.0, -1.0}; 32 33 svm_problem problem = new svm_problem(); 34 problem.l = 2; 35 problem.x = datas; 36 problem.y = labels; 37 38 svm_parameter param = new svm_parameter(); 39 param.svm_type = svm_parameter.C_SVC; 40 param.kernel_type = svm_parameter.LINEAR; 41 param.cache_size = 100; 42 param.eps = 0.00001; 43 param.C = 1; 44 45 46 System.out.println(svm.svm_check_parameter(problem, param)); 47 svm_model model = svm.svm_train(problem, param); 48 49 svm_node pc0 = new svm_node(); 50 pc0.index = 0; 51 pc0.value = -0.1; 52 svm_node pc1 = new svm_node(); 53 pc1.index = -1; 54 pc1.value = 0; 55 56 svm_node[] pc = {pc0, pc1}; 57 58 System.out.println(svm.svm_predict(model, pc)); 59 } 60 }