zoukankan      html  css  js  c++  java
  • libsvm使用简介

    libsvm是support vector machine的一种开源实现,采用了smo算法。源代码编写有独到之处,值得一睹。

    常用结构

    svm_node结构

    定义了构成输入特征向量的元素,index为索引(= -1为最后一个元素),value为值,

    public class svm_node implements java.io.Serializable
    {
        public int index;
        public double value;
    }

     借鉴了稀疏矩阵的表示方法。对于一个输入向量,定义为svm_node构成的一维数组

    svm_node[] pa = {pa0, pa1};

     所有输入序列有一个二维数组表示

    svm_node[][] datas = {pa, pb};

    标记序列

    就是一个double数组,对应于输入序列datas的每一维。

    double[] labels = {1.0, -1.0};

    svm_problem结构

    定义了(X, Y)的训练样本结构

    public class svm_problem implements java.io.Serializable
    {
        public int l;
        public double[] y;
        public svm_node[][] x;
    }

    其中l是样本数量。

    svm_parameter结构

    定义了训练时的重要参数

    public class svm_parameter implements Cloneable,java.io.Serializable
    {
        /* svm_type */
        public static final int C_SVC = 0;
        public static final int NU_SVC = 1;
        public static final int ONE_CLASS = 2;
        public static final int EPSILON_SVR = 3;
        public static final int NU_SVR = 4;
    
        /* kernel_type */
        public static final int LINEAR = 0;
        public static final int POLY = 1;
        public static final int RBF = 2;
        public static final int SIGMOID = 3;
        public static final int PRECOMPUTED = 4;
    
        public int svm_type;
        public int kernel_type;
        public int degree;    // for poly
        public double gamma;    // for poly/rbf/sigmoid
        public double coef0;    // for poly/sigmoid
    
        // these are for training only
        public double cache_size; // in MB
        public double eps;    // stopping criteria
        public double C;    // for C_SVC, EPSILON_SVR and NU_SVR
        public int nr_weight;        // for C_SVC
        public int[] weight_label;    // for C_SVC
        public double[] weight;        // for C_SVC
        public double nu;    // for NU_SVC, ONE_CLASS, and NU_SVR
        public double p;    // for EPSILON_SVR
        public int shrinking;    // use the shrinking heuristics
        public int probability; // do probability estimates
    
        public Object clone() 
        {
            try 
            {
                return super.clone();
            } catch (CloneNotSupportedException e) 
            {
                return null;
            }
        }
    
    }

    主要分为两大类参数:分类器的核函数性质和训练算法SMO的一些参数,包括精度啊等等

    训练

    通过调用svm.svm_train()训练模型

    public static svm_model svm_train(svm_problem prob, svm_parameter param)

    返回svm_model类对象表示训练得到的分类器

    预测

    通过svm.svm_predict()利用分类器进行预测

    public static double svm_predict(svm_model model, svm_node[] x)

    返回类别标记

    实例代码如下,输入点pa = (10.0 10.0) ya = 1.0 pb = (-10.0, -10.0) yb = -1.0

    测试点 (-0.1, 0)

     1 import libsvm.svm;
     2 import libsvm.svm_model;
     3 import libsvm.svm_node;
     4 import libsvm.svm_parameter;
     5 import libsvm.svm_problem;
     6 
     7 public class SvmTest {
     8     public static void main(String[] args) {
     9         
    10         svm_node pa0 = new svm_node();
    11         pa0.index = 0;
    12         pa0.value = 10.0;
    13         
    14         svm_node pa1 = new svm_node();
    15         pa1.index = -1;
    16         pa1.value = 10.0;
    17         
    18         svm_node pb0 = new svm_node();
    19         pb0.index = 0;
    20         pb0.value = -10.0;
    21         
    22         svm_node pb1 = new svm_node();
    23         pb1.index = -1;
    24         pb1.value = -10.0;
    25         
    26         svm_node[] pa = {pa0, pa1};
    27         svm_node[] pb = {pb0, pb1};
    28         
    29         svm_node[][] datas = {pa, pb};
    30         
    31         double[] labels = {1.0, -1.0};
    32         
    33         svm_problem problem = new svm_problem();
    34         problem.l = 2;
    35         problem.x = datas;
    36         problem.y = labels;
    37         
    38         svm_parameter param = new svm_parameter();
    39         param.svm_type = svm_parameter.C_SVC;
    40         param.kernel_type = svm_parameter.LINEAR;
    41         param.cache_size = 100;
    42         param.eps = 0.00001;
    43         param.C = 1;
    44         
    45         
    46         System.out.println(svm.svm_check_parameter(problem, param));
    47         svm_model model = svm.svm_train(problem, param);
    48         
    49         svm_node pc0 = new svm_node();
    50         pc0.index = 0;
    51         pc0.value = -0.1;
    52         svm_node pc1 = new svm_node();
    53         pc1.index = -1;
    54         pc1.value = 0;
    55         
    56         svm_node[] pc = {pc0, pc1};
    57         
    58         System.out.println(svm.svm_predict(model, pc));
    59     }
    60 }
  • 相关阅读:
    阿里消息队列中间件 RocketMQ 源码分析 —— Message 拉取与消费(上)
    数据库中间件 ShardingJDBC 源码分析 —— SQL 解析(三)之查询SQL
    数据库分库分表中间件 ShardingJDBC 源码分析 —— SQL 解析(六)之删除SQL
    数据库分库分表中间件 ShardingJDBC 源码分析 —— SQL 解析(五)之更新SQL
    消息队列中间件 RocketMQ 源码分析 —— Message 存储
    源码圈 300 胖友的书单整理
    数据库分库分表中间件 ShardingJDBC 源码分析 —— SQL 路由(一)分库分表配置
    数据库分库分表中间件 ShardingJDBC 源码分析 —— SQL 解析(四)之插入SQL
    数据库分库分表中间件 ShardingJDBC 源码分析 —— SQL 路由(二)之分库分表路由
    C#中Math类的用法
  • 原文地址:https://www.cnblogs.com/zjgtan/p/3305720.html
Copyright © 2011-2022 走看看