zoukankan      html  css  js  c++  java
  • LibSVM C/C++


    本系列文章由 @YhL_Leo 出品,转载请注明出处。
    文章链接: http://blog.csdn.net/yhl_leo/article/details/50179779


    LibSVM的库的svm.h头文件中定义了四个主要结构体:

    1 训练模型的结构体

    // The training set: l samples, where sample i has the label y[i] and the
    // feature vector x[i].
    struct svm_problem
    {
        int l;                // total number of samples
        double *y;            // label of each sample
        struct svm_node **x;  // feature vector of each sample
    };

    样本的类别通常使用 +1 和 -1 进行标识。如果样本的类别未知,则分类的准确率也就无法计算。

    2 数据节点的结构体

    // One index:value pair of a feature vector.
    // The loader shown later in this article terminates every sample's node
    // array with an entry whose index is -1.
    struct svm_node
    {
        int index;     // feature index of this entry
        double value;  // feature value stored at that index
    };

    数据组织结构如图1所示:

    3 模型参数结构体

    // Model and solver settings; the fields correspond to the command-line
    // options listed below this struct.
    struct svm_parameter
    {
        int svm_type;       /* one of the svm_type enum values (C_SVC ... NU_SVR) */
        int kernel_type;    /* one of the kernel_type enum values (LINEAR ... PRECOMPUTED) */
        int degree; /* for poly */
        double gamma;   /* for poly/rbf/sigmoid */
        double coef0;   /* for poly/sigmoid */
    
        /* these are for training only */
        double cache_size; /* in MB */
        double eps; /* stopping criteria */
        double C;   /* for C_SVC, EPSILON_SVR and NU_SVR */
        int nr_weight;      /* for C_SVC */
        int *weight_label;  /* for C_SVC */
        double* weight;     /* for C_SVC */
        double nu;  /* for NU_SVC, ONE_CLASS, and NU_SVR */
        double p;   /* for EPSILON_SVR */
        int shrinking;  /* use the shrinking heuristics */
        int probability; /* do probability estimates */
    };

    其中,各个参数的含义为:

    -s svm_type : set type of SVM (default 0)
        0 -- C-SVC
        1 -- nu-SVC
        2 -- one-class SVM
        3 -- epsilon-SVR
        4 -- nu-SVR
    -t kernel_type : set type of kernel function (default 2)
        0 -- linear: u'*v
        1 -- polynomial: (gamma*u'*v + coef0)^degree
        2 -- radial basis function: exp(-gamma*|u-v|^2)
        3 -- sigmoid: tanh(gamma*u'*v + coef0)
    -d degree : set degree in kernel function (default 3)
    -g gamma : set gamma in kernel function (default 1/num_features)
    -r coef0 : set coef0 in kernel function (default 0)
    -c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
    -n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
    -p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
    -m cachesize : set cache memory size in MB (default 100)
    -e epsilon : set tolerance of termination criterion (default 0.001)
    -h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
    -b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
    -wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)

    SVM模型类型和核函数类型:

    /* Enumerator order matches the numeric codes of the -s and -t options
       listed above (C_SVC = 0, ..., NU_SVR = 4; LINEAR = 0, ..., SIGMOID = 3);
       PRECOMPUTED takes the next value, 4. */
    enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */
    enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */

    4 训练输出模型结构体

    /* Trained model: the parameters it was trained with, the support vectors
       and the per-decision-function coefficients and constants. */
    struct svm_model
    {
        struct svm_parameter param; /* parameter */
        int nr_class;       /* number of classes, = 2 in regression/one class svm */
        int l;          /* total #SV */
        struct svm_node **SV;       /* SVs (SV[l]) */
        double **sv_coef;   /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
        double *rho;        /* constants in decision functions (rho[k*(k-1)/2]) */
        double *probA;      /* pairwise probability information */
        double *probB;
        int *sv_indices;        /* sv_indices[0,...,nSV-1] are values in [1,...,num_training_data] to indicate SVs in the training set */
    
        /* for classification only */
    
        int *label;     /* label of each class (label[k]) */
        int *nSV;       /* number of SVs for each class (nSV[k]) */
                    /* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
        /* XXX */
        int free_sv;        /* 1 if svm_model is created by svm_load_model*/
                    /* 0 if svm_model is created by svm_train */
    };

    5 使用方法

    以LibSVM提供的样本特征集heart_scale为例,首先需要读取样本特征数据,可以利用svm-train.c文件中的read_problem函数;为了方便使用,这里对其进行了改写:

    // TrainingDataLoad.h
    /*
        Load training data from svm format file.
    
        - Editor: Yahui Liu.
        - Date:   2015-11-30
        - Email:  yahui.cvrs@gmail.com
        - Address: Computer Vision and Remote Sensing(CVRS), Lab.
    **/
    
    #ifndef TRAINING_DATA_LOAD_H
    #define TRAINING_DATA_LOAD_H
    #pragma once
    
    #include <ctype.h>
    #include <errno.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>

    #include <fstream>
    #include <iostream>
    #include <string>
    #include <vector>
    
    #include "svm.h"
    //#include "svm-scale.c"
    
    using namespace std;
    
    #define MAX_LINE_LEN 1024
    
    /*!
        Reads LIBSVM-format text files into an svm_problem and loads saved models.

        `line` is the text buffer that readline() fills and readProblem()
        tokenizes. The constructor and destructor only reset the pointer to
        NULL; neither of them allocates or releases the buffer.
    **/
    class TrainingDateLoad
    {
    public:
        /*! buffer holding the most recently read input line (NULL until one is set) */
        char* line;

        TrainingDateLoad() : line(NULL) {}

        /*! clears the pointer only -- the buffer it pointed to is not freed here */
        ~TrainingDateLoad() { line = NULL; }

    // public:
    //  static struct svm_parameter _paramInit;

        /*! load svm model */
        void loadModel( std::string filename,  struct svm_model*& model);

        /*! move p past leading whitespace and the label token that follows it */
        void svmSkipTarget( char*& p);

        /*! move p past the next "index:value" element */
        void svmSkipElement( char*& p);

        /*! fill param with the default settings */
        void initialParams( struct svm_parameter& param );

        /*! load training data */
        void readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param );

        /*! read one complete line from input into `line`; NULL when nothing could be read */
        char* readline(FILE *input);

        /*! report the offending (1-based) line number on stdout and exit with status 1 */
        void exit_input_error(int line_num)
        {
            std::cout << "Wrong input format at line: " << line_num << std::endl;
            exit(1);
        }
    };
    
    #endif // TRAINING_DATA_LOAD_H
    // TrainingDataLoad.cpp
    #include "TrainingDataLoad.h"
    
    /*!
        Load a saved model from `filename` with svm_load_model().
        Whatever svm_load_model() returns is stored in `model` unchanged.
    **/
    void TrainingDateLoad::loadModel(std::string filename, struct svm_model*& model)
    {
        const char* const modelPath = filename.c_str();
        struct svm_model* const loaded = svm_load_model(modelPath);
        model = loaded;
    }
    
    /*!
        Advance p past the first token of a line (the label / "target").

        Leading whitespace is skipped first, then the run of non-whitespace
        characters that follows. On return p points at the whitespace after
        the token, or at the terminating NUL if the string ends there.
        A NULL p is left untouched.
    **/
    void TrainingDateLoad::svmSkipTarget(char*& p)
    {
        if(p == NULL)
            return;

        // isspace() requires a value representable as unsigned char; passing a
        // negative plain char is undefined behaviour.
        while(*p && isspace(static_cast<unsigned char>(*p))) ++p;

        // Stop at the terminating NUL as well: without that check a line that
        // holds only a label makes the scan run past the end of the buffer.
        while(*p && !isspace(static_cast<unsigned char>(*p))) ++p;
    }
    
    /*!
        Advance p past the next "index:value" element.

        p is moved to the next ':', past it, past any whitespace after it and
        then past the value token. If no ':' remains, p is left at the
        terminating NUL. A NULL p is left untouched.
    **/
    void TrainingDateLoad::svmSkipElement(char*& p)
    {
        if(p == NULL)
            return;

        // Bound the search by the terminating NUL; a string without ':' would
        // otherwise be scanned beyond its end.
        while(*p && *p!=':') ++p;
        if(*p == '\0')
            return;

        ++p;
        // Cast before isspace(): a negative plain char is undefined behaviour.
        while(*p && isspace(static_cast<unsigned char>(*p))) ++p;
        while(*p && !isspace(static_cast<unsigned char>(*p))) ++p;
    }
    
    /*!
        Fill `param` with the default settings: C-SVC with an RBF kernel,
        no per-class weights, no probability estimates.
    **/
    void TrainingDateLoad::initialParams( struct svm_parameter& param )
    {
        // model family
        param.svm_type    = C_SVC;
        param.kernel_type = RBF;

        // kernel coefficients
        param.degree = 3;
        param.gamma  = 0;    // placeholder: readProblem() replaces 0 with 1/max feature index
        param.coef0  = 0;

        // solver settings
        param.C          = 1;
        param.nu         = 0.5;
        param.p          = 0.1;
        param.eps        = 1e-3;
        param.cache_size = 100;
        param.shrinking   = 1;
        param.probability = 0;

        // no class weights
        param.nr_weight    = 0;
        param.weight_label = NULL;
        param.weight       = NULL;
    }
    
    void TrainingDateLoad::readProblem( std::string filename, 
        struct svm_problem& prob, struct svm_parameter& param )
    {
        int max_index, inst_max_index, i;
        size_t elements, j;
        FILE *fp = fopen(filename.c_str(),"r");
        char *endptr;
        char *idx, *val, *label;
    
        if(fp == NULL)
        {
            fprintf(stderr,"can't open input file %s
    ",filename);
            exit(1);
        }
    
        prob.l = 0;
        elements = 0;
    
        line = new char[MAX_LINE_LEN];
        while(readline(fp)!=NULL)
        {
            char *p = strtok(line," 	"); // label
    
            // features
            while(1)
            {
                p = strtok(NULL," 	");
                if(p == NULL || *p == '
    ') // check '
    ' as ' ' may be after the last feature
                    break;
                ++elements;
            }
            ++elements;
            ++prob.l;
        }
        rewind(fp);
    
        prob.y = new double[prob.l];
        prob.x = new struct svm_node *[prob.l];
        struct svm_node *x_space = new struct svm_node[elements];
    
        max_index = 0;
        j=0;
        for(i=0;i<prob.l;i++)
        {
            inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
            readline(fp);
            prob.x[i] = &x_space[j];
            label = strtok(line," 	
    ");
            if(label == NULL) // empty line
                exit_input_error(i+1);
    
            prob.y[i] = strtod(label,&endptr);
            if(endptr == label || *endptr != '')
                exit_input_error(i+1);
    
            while(1)
            {
                idx = strtok(NULL,":");
                val = strtok(NULL," 	");
    
                if(val == NULL)
                    break;
    
                errno = 0;
                x_space[j].index = (int) strtol(idx,&endptr,10);
                if(endptr == idx || errno != 0 || *endptr != '' || x_space[j].index <= inst_max_index)
                    exit_input_error(i+1);
                else
                    inst_max_index = x_space[j].index;
    
                errno = 0;
                x_space[j].value = strtod(val,&endptr);
                if(endptr == val || errno != 0 || (*endptr != '' && !isspace(*endptr)))
                    exit_input_error(i+1);
    
                ++j;
            }
    
            if(inst_max_index > max_index)
                max_index = inst_max_index;
            x_space[j++].index = -1;
        }
    
        if(param.gamma == 0 && max_index > 0)
            param.gamma = 1.0/max_index;
    
        if(param.kernel_type == PRECOMPUTED)
            for(i=0;i<prob.l;i++)
            {
                if (prob.x[i][0].index != 0)
                {
                    fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number
    ");
                    exit(1);
                }
                if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
                {
                    fprintf(stderr,"Wrong input format: sample_serial_number out of range
    ");
                    exit(1);
                }
            }
    
            fclose(fp);
    }
    
    /*!
        Read one complete text line from `input` into the member buffer `line`.

        The first read uses MAX_LINE_LEN bytes; while no '\n' has been seen
        the buffer size is doubled with realloc() and reading continues.
        Because of that realloc() call the buffer must come from
        malloc()/realloc(); if `line` is still NULL a buffer of MAX_LINE_LEN
        bytes is allocated here.

        Returns `line`, or NULL when nothing could be read (end of file or
        read error). Running out of memory terminates the program with
        exit(1).
    **/
    char* TrainingDateLoad::readline(FILE *input)
    {
        if(line == NULL)
        {
            line = static_cast<char *>(malloc(MAX_LINE_LEN));
            if(line == NULL)
            {
                fprintf(stderr,"can't allocate the line buffer\n");
                exit(1);
            }
        }

        if(fgets(line,MAX_LINE_LEN,input) == NULL)
            return NULL;

        // The size counter starts over on every call; only the size used for
        // the current line is tracked.
        size_t capacity = MAX_LINE_LEN;
        while(strrchr(line,'\n') == NULL)
        {
            capacity *= 2;

            // Keep the old pointer until realloc() succeeded; assigning the
            // result directly would lose the buffer on failure.
            char *grown = static_cast<char *>(realloc(line,capacity));
            if(grown == NULL)
            {
                fprintf(stderr,"can't enlarge the line buffer\n");
                exit(1);
            }
            line = grown;

            const size_t len = strlen(line);
            if(fgets(line+len,static_cast<int>(capacity-len),input) == NULL)
                break; // last line of the file has no trailing newline
        }
        return line;
    }

    将样本训练与预测进行改写:

    // LibSVMTools.h
    /*
        LibSVM train and predict tools.
    
        - Editor: Yahui Liu.
        - Date:   2015-12-3
        - Email:  yahui.cvrs@gmail.com
        - Address: Computer Vision and Remote Sensing(CVRS), Lab.
    **/
    
    #ifndef LIBSVM_TOOL_H
    #define LIBSVM_TOOL_H
    #pragma once
    
    #include <iostream>
    #include <string>
    
    #include "svm.h"
    #include "TrainingDataLoad.h"
    
    /*!
        Two-call front end for LibSVM: train a model from a feature file,
        and run a saved model over a feature file.
    **/
    class LibSVMTools
    {
    public:
        LibSVMTools() {}
        ~LibSVMTools() {}

        /*!
            Train a model and write it to disk.

            - featureFile:   features of images saved in libsvm format.
            - saveModelFile: where the trained model file is written.
        **/
        void libSvmTrain(std::string featureFile, std::string saveModelFile);

        /*!
            Predict every sample of a feature file with a saved model.

            - featureFile:     features of images saved in libsvm format.
            - modelFile:       libsvm trained model.
            - savePredictFile: where the predicted labels are written.
        **/
        void libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile);
    };
    
    #endif // LIBSVM_TOOL_H
    // LibSVMTools.cpp
    #include "LibSVMTools.h"
    
    void LibSVMTools::libSvmTrain(std::string featureFile, std::string saveModelFile)
    {
        struct svm_parameter param;
        struct svm_problem prob;
    
        TrainingDateLoad* trainData = new TrainingDateLoad;
        trainData->initialParams( param );
        trainData->readProblem(featureFile, prob, param);
    
        const char*errorMsg = svm_check_parameter(&prob, &param);
        if ( errorMsg )
        {
            cout << errorMsg << endl;
            return;
        }
    
        struct svm_model *model = svm_train(&prob, &param);
    
    #if 1
        cout << "svm_type: " << model->param.svm_type << endl <<
            "kernel_type: " << model->param.kernel_type << endl <<
            "gamma: " << model->param.gamma << endl <<
            "nr_class: " << model->nr_class << endl <<
            "total_sv: " << model->l << endl <<
            "rho: " << model->rho[0] << endl <<
            "label: " << model->label[0] << " " << model->label[1] << endl <<
            "nr_sv: " << model->nSV[0] << " " << model->nSV[1] << endl;
    #endif
    
    
        int saveModel = svm_save_model( saveModelFile.c_str(), model );
    }
    
    void LibSVMTools::libSvmPredict(std::string featureFile, 
        std::string modelFile, std::string savePredictFile)
    {
        struct svm_parameter param;
        struct svm_problem prob;
    
        TrainingDateLoad * trainData = new TrainingDateLoad;
        trainData->initialParams( param );
        trainData->readProblem(featureFile, prob, param);
    
        struct svm_model* model;
        trainData->loadModel(modelFile.c_str(), model);
    
        float correct(0.0);     // all correct
        float uncorrect_1(0.0); // pos to neg
        float uncorrect_2(0.0); // neg to pos
        if ( prob.l )
        {
            const int nCount = prob.l;;
    
            ofstream outfile( savePredictFile, ios::out );
            for( int i=0; i<nCount; i++ )
            {
                double label = svm_predict(model, prob.x[i]);
                if ( label == prob.y[i] )
                {
                    correct ++;
                }
                else if ( label == -1.0 )
                {
                    uncorrect_1 ++;
                }
                else
                {
                    uncorrect_2 ++;
                }
                outfile << label << endl;
            }
    #if 1
            cout << "total data count: " << nCount << endl <<
                "classification correct: " << correct << endl <<
                "pos to neg count: " << uncorrect_1 << endl <<
                "neg to pos count: " << uncorrect_2 << endl;
    
            cout << "Accuracy: " << static_cast<float>(correct/nCount) 
                << "(" << correct << "/" << nCount << ")" << endl;
    #endif
            outfile.close();
        }
    }

    用例Demo:

    // train
    #include "LibSVMTools.h"
    
    /*!
        Training demo: trains a model on the heart_scale sample set and saves
        it as train.model in the same Data directory.
    **/
    int main()
    {
        std::cout << 
            "************************************************************" << std::endl <<
            "**          PROGRAM: LibSVM model training.               **" << std::endl << 
            "**                                                        **" << std::endl <<
            "**           Author: Yahui Liu.                           **" << std::endl <<
            "**                   School of Remote Sensing & Inf. Eng. **" << std::endl << 
            "**                   Wuhan University, Hubei, P.R. China  **" << std::endl <<
            "**            Email: yahui.cvrs@gmail.com                 **" << std::endl <<
            "**      Create time: Dec. 1, 2015                         **" << std::endl <<
            "************************************************************" << std::endl;

        // Backslashes in a Windows path must be escaped inside a string
        // literal. The data set is named heart_scale, as in the predict demo.
        std::string filename = "..\\..\\..\\Data\\heart_scale";
        std::string savefielname = "..\\..\\..\\Data\\train.model";

        LibSVMTools libsvm;
        libsvm.libSvmTrain(filename, savefielname);

        return 0;
    }
    
    /*------------------------------------------------------------------------------------*/
    
    // predict
    #include "LibSVMTools.h"
    
    /*!
        Prediction demo: runs the saved train.model over the heart_scale
        sample set and writes the predicted labels to predict.out.
    **/
    int main()
    {
        std::cout << 
            "************************************************************" << std::endl <<
            "**          PROGRAM: LibSVM predict.                      **" << std::endl << 
            "**                                                        **" << std::endl <<
            "**           Author: Yahui Liu.                           **" << std::endl <<
            "**                   School of Remote Sensing & Inf. Eng. **" << std::endl << 
            "**                   Wuhan University, Hubei, P.R. China  **" << std::endl <<
            "**            Email: yahui.cvrs@gmail.com                 **" << std::endl <<
            "**      Create time: Dec. 1, 2015                         **" << std::endl <<
            "************************************************************" << std::endl;

        // Backslashes in a Windows path must be escaped inside a string literal.
        std::string featureFile = "..\\..\\..\\Data\\heart_scale";
        std::string modelFile = "..\\..\\..\\Data\\train.model";
        std::string savePredictFile = "..\\..\\..\\Data\\predict.out";

        LibSVMTools libsvm;
        libsvm.libSvmPredict(featureFile, modelFile, savePredictFile);

        return 0;
    }
  • 相关阅读:
    设置跨域
    Vs自定nuget push菜单
    VS IIS Express 支持局域网访问
    字符串GZIP压缩解压
    C# 使用 protobuf 进行对象序列化与反序列化
    RabbitMQ
    如果调用.net core Web API不能发送PUT/DELETE请求怎么办?
    log4net配置使用
    redis实现消息队列
    Error-the resource is not on the build path of a java project
  • 原文地址:https://www.cnblogs.com/hehehaha/p/6332210.html
Copyright © 2011-2022 走看看