zoukankan      html  css  js  c++  java
  • [OpenCV随笔]-OpenCV3.x中SVM多分类使用(代码篇)

    1. SVM介绍

    占个坑,以后再说

    2. OpenCV3.x下SVM接口介绍

    官方文档
    OpenCV3.x与OpenCV2.x中SVM的接口有了很大变化,在接口上使用了虚函数取代以前的定义。
    下面介绍几个常用的接口,及其参数意义。

    2.1 初始化函数

    定义如下:

    CV_WRAP static Ptr<SVM> create();
    

    2.2 参数设置函数

    然后是一些设置SVM参数的函数:

    CV_WRAP virtual int getType() const = 0;
    CV_WRAP virtual void setType(int val) = 0;
    
    CV_WRAP virtual double getGamma() const = 0;
    CV_WRAP virtual void setGamma(double val) = 0;
    
    CV_WRAP virtual double getDegree() const = 0;
    CV_WRAP virtual void setDegree(double val) = 0;
    
    CV_WRAP virtual double getC() const = 0;
    CV_WRAP virtual void setC(double val) = 0;
    
    CV_WRAP virtual double getNu() const = 0;
    CV_WRAP virtual void setNu(double val) = 0;
    
    CV_WRAP virtual double getP() const = 0;
    CV_WRAP virtual void setP(double val) = 0;
    
    CV_WRAP virtual cv::Mat getClassWeights() const = 0;
    CV_WRAP virtual void setClassWeights(const cv::Mat &val) = 0;
    
    CV_WRAP virtual cv::TermCriteria getTermCriteria() const = 0;
    CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;
    
    CV_WRAP virtual int getKernelType() const = 0;
    CV_WRAP virtual void setKernel(int kernelType) = 0;
    

    具体的作用可以参考OpenCV文档,这里只介绍两个常用的函数:

    //设置SVM类型
    CV_WRAP virtual int getType() const = 0;
    

    这个函数用于设置SVM类型,OpenCV提供了五种类型:

    Types { 
        //C类支持向量分类机。 n类分组 (n≥2),容许用异常值处罚因子C进行不完全分类。
        C_SVC =100, 
    
        //$v$类支持向量机
        NU_SVC =101, 
    
        //单分类器,所有的练习数据提取自同一个类里,
        //然后SVM建树了一个分界线以分别该类在特点空间
        //中所占区域和其它类在特点空间中所占区域。
        ONE_CLASS =102, 
    
        EPS_SVR =103, 
    
        NU_SVR =104 
    }
    

    一般我们使用SVM进行二分类或者多分类任务,选择第一种SVM::C_SVC即可。
    还有一个函数就是:

    CV_WRAP virtual void setKernel(int kernelType) = 0;
    

    这个函数用于设置SVM的核函数类型,我们知道,通过选择SVM的核函数可以使SVM处理高阶、非线性问题。OpenCV提供几种核函数:

    enum KernelTypes {
        /** Returned by SVM::getKernelType in case when custom kernel has been set */
        CUSTOM=-1,
        
        //线性核
        LINEAR=0,
        
        //多项式核
        POLY=1,
        
        //径向基核(高斯核)
        RBF=2,
        
        //sigmoid核
        SIGMOID=3,
        
        //指数核,与高斯核类似
        CHI2=4,
        
        //直方图核
        INTER=5
    };
    

    一般情况下使用径向基核可以很好处理大部分情况。

    2.3 训练函数

    OpenCV3.x中SVM的提供了训练函数也与2.x不同,如下:

    virtual bool trainAuto( const Ptr<TrainData>& data, int kFold = 10,
                    ParamGrid Cgrid = getDefaultGrid(C),
                    ParamGrid gammaGrid  = getDefaultGrid(GAMMA),
                    ParamGrid pGrid      = getDefaultGrid(P),
                    ParamGrid nuGrid     = getDefaultGrid(NU),
                    ParamGrid coeffGrid  = getDefaultGrid(COEF),
                    ParamGrid degreeGrid = getDefaultGrid(DEGREE),
                    bool balanced=false) = 0;
    
    bool trainAuto (InputArray samples, int layout, InputArray responses, 
                    int kFold=10, Ptr< ParamGrid > Cgrid=SVM::getDefaultGridPtr(SVM::C), 
                    Ptr< ParamGrid > gammaGrid=SVM::getDefaultGridPtr(SVM::GAMMA), 
                    Ptr< ParamGrid > pGrid=SVM::getDefaultGridPtr(SVM::P), 
                    Ptr< ParamGrid > nuGrid=SVM::getDefaultGridPtr(SVM::NU), 
                    Ptr< ParamGrid > coeffGrid=SVM::getDefaultGridPtr(SVM::COEF), 
                    Ptr< ParamGrid > degreeGrid=SVM::getDefaultGridPtr(SVM::DEGREE), 
                    bool balanced=false)
    

    trainAuto可以在训练过程中自动优化2.2中的那些参数,而使用train函数时,参数被固定,所以推荐使用trainAuto函数。
    在准备训练数据的时候,有下面几点需要注意,否则函数会报错

    1. SVM的训练函数是ROW_SAMPLE类型的,也就是说,送入SVM训练的特征需要reshape成一个行向量,所有训练数据全部保存在一个Mat中,一个训练样本就是Mat中的一行,最后还要讲这个Mat转换成CV_32F类型,例如,如果有(k)个样本,每个样本原本维度是((h, w)),则转换后Mat的维度为((k, h * w))
    2. 对于多分类问题,label矩阵的行数要与样本数量一致,也就是每个样本要在label矩阵中有一个对应的标签,label的列数为1,因为对于一个样本,SVM输出一个值,我们在训练前需要做的就是设计这个值与样本的对应关系。对于有(k)个样本的情况,label的维度是((k, 1))

    2.4 预测函数

    函数定义如下:

    float predict(cv::InputArrat samples, cv::OutputArray results = noArray(), int flags = 0) const;
    

    其中samples就是需要预测的样本,这里样本同样要转换成ROW_SAMPLE和CV_32F格式,对于单个测试样本的情况,预测结果直接通过函数返回值返回,而如果samples中有多个样本,就需要穿进result参数,预测结果以列向量的方式保存在result数组中。假如有(k)个样本,每个样本原本的维度为((h, w)),则samples的维度为((k, h * w)),最终预测结果result维度为((k, 1))

    3. 例程

    下面上代码:

    /*
    * 把图片从vector<Mat>格式转换成SVM的RAW_SAMPLE格式
    */
    void transform(const vector<Mat> &split, Mat &testData)
    {
        for (auto it = split.begin(); it != split.end(); it++){
            Mat tmp;
            resize(*it, tmp, Size(28, 28));
            testData.push_back(tmp.reshape(0, 1));
        }
    
        testData.convertTo(testData, CV_32F);
    }
    
    /*
    * 从文件list.txt中读取测试数据和标签,输出SVM的Mat格式
    */
    void get_data(string path, Mat &trainData, Mat &trainLabels)
    {
        fstream io(path, ios::in);
        if (!io.is_open()){
            cout << "file open error in path : " << path << endl;
            exit(0);
        }
    
        while (!io.eof())
        {
            string msg;
            io >> msg;
    
            trainData.push_back(imread(msg, 0).reshape(0, 1));
    
            io >> msg;
            int idx = msg[0] - '0';
            //trainLabels.push_back(Mat_<int>(1, 1) << idx);  //用这种方式会报错,原因尚且不明
            trainLabels.push_back(Mat(1, 1, CV_32S, &idx));
        }
    
        trainData.convertTo(trainData, CV_32F);
    }
    
    /*
    * 训练SVM
    */
    void svm_train(Ptr<SVM> &model, Mat &trainData, Mat &trainLabels)
    {
        model->setType(SVM::C_SVC);     //SVM类型
        model->setKernel(SVM::LINEAR);  //核函数,这里使用线性核
    
        Ptr<TrainData> tData = TrainData::create(trainData, ROW_SAMPLE, trainLabels);
    
        cout << "SVM: start train ..." << endl;
        model->trainAuto(tData);
        cout << "SVM: train success ..." << endl;
    }
    
    /*
    * 利用训练好的SVM预测
    */
    void svm_pridect(Ptr<SVM> &model, Mat test)
    {
        Mat result;
        float rst = model->predict(test, result);
        for (auto i = 0; i < result.rows; i++){
            cout << result.at<float>(i, 0);
        }
    }
    
    int main(int argc, const char** argv)
    {
        fstream io;
        io.open("test_list.txt", ios::in);
    
        string train_path = "train_list.txt";
            
        vector<Mat> test_set;
        get_test(io, test_set);
    
        Ptr<SVM> model = SVM::create();
        Mat trainData, trainLabels;
        get_data(train_path, trainData, trainLabels);
        svm_train(model, trainData, trainLabels);
    
        Mat testData;
        transform(test_set, testData);
        svm_pridect(model, testData);
    }
    

    trian_list.txt文件格式是这样的:

    D:ImgProProjectforcharcodeeta00	rain_data-1.jpg		0
    D:ImgProProjectforcharcodeeta00	rain_data-2.jpg		0
    

    每行前一段表示训练图片地址,最后的数字表示这个图片对应标签
    test_list.txt中格式与train_list.txt差不多,只是没有了标签。

    作者:Brccq
    出处:http://www.cnblogs.com/br170525// 出处:http://www.loveyfyq.com//
    欢迎转载,必须在文章页面明显位置给出原文连接,如需本博文源代码或者有任何问题,请在博文留下您的邮箱或者问题说明。

  • 相关阅读:
    Java多线程性能优化
    It is indirectly referenced from required .class files
    Switch基本知识
    HibernateTemplate 查询
    Hibernate工作原理及为什么要用?
    深入Java集合学习系列:HashMap的实现原理
    sql查询语句中的乱码 -- 前面加N
    Windows 8.1内置微软五笔输入法
    the rendering library is more recent than your version of android studio
    JBoss vs. Tomcat
  • 原文地址:https://www.cnblogs.com/br170525/p/9236479.html
Copyright © 2011-2022 走看看