zoukankan      html  css  js  c++  java
  • [code segments] OpenCV3.0 SVM with C++ interface

    Talk is cheap; here is the code:

    /************************************************************************/
    /* Name   : OpenCV SVM test                                             */
    /* Date   : 2015/11/7                                                   */
    /* Author : aban                                                        */
    /************************************************************************/
    // Note: this code is adapted from an example found on the internet.
    
    #include <iostream>
    #include <cmath>
    #include <string>
    using namespace std;
    
    #include <opencv2/opencv.hpp>
    #include <opencv2/ml.hpp>
    using namespace cv;
    
    // Demo configuration. None of these values is modified at run time, so
    // they are const to prevent accidental writes.
    const bool plotSupportVectors = true;  // also open a window showing the support vectors
    const int numTrainingPoints = 200;     // number of random training samples
    const int numTestPoints = 2000;        // number of random test samples
    // Side length, in pixels, of each plot window. Refer to it as `::size`:
    // under C++17, `using namespace std;` also makes std::size visible at
    // global scope, and an unqualified `size` is then ambiguous.
    const int size = 200;
    const int eq = 0;                      // boundary selector passed to f() when labelling
    
    // accuracy
    float evaluate(cv::Mat& predicted, cv::Mat& actual) {
      assert(predicted.rows == actual.rows);
      int t = 0;
      int f = 0;
      for (int i = 0; i < actual.rows; i++) {
        float p = predicted.at<float>(i, 0);
        float a = actual.at<float>(i, 0);
        if ((p >= 0.0 && a >= 0.0) || (p <= 0.0 &&  a <= 0.0)) {
          t++;
        }
        else {
          f++;
        }
      }
      return (t * 1.0) / (t + f);
    }
    
    // plot data and class
    void plot_binary(cv::Mat& data, cv::Mat& classes, string name) {
      cv::Mat plot(size, size, CV_8UC3);
      plot.setTo(cv::Scalar(255.0, 255.0, 255.0));
      for (int i = 0; i < data.rows; i++) {
    
        float x = data.at<float>(i, 0) * size;
        float y = data.at<float>(i, 1) * size;
    
        if (classes.at<float>(i, 0) > 0) {
          cv::circle(plot, Point(x, y), 2, CV_RGB(255, 0, 0), 1);
        }
        else {
          cv::circle(plot, Point(x, y), 2, CV_RGB(0, 255, 0), 1);
        }
      }
      cv::namedWindow(name, CV_WINDOW_KEEPRATIO);
      cv::imshow(name, plot);
    }
    
    // function to learn
    int f(float x, float y, int equation) {
      switch (equation) {
      case 0:
        return y > sin(x * 10) ?

    -1 : 1; break; case 1: return y > cos(x * 10) ? -1 : 1; break; case 2: return y > 2 * x ?

    -1 : 1; break; case 3: return y > tan(x * 10) ?

    -1 : 1; break; default: return y > cos(x * 10) ?

    -1 : 1; } } // label data with equation cv::Mat labelData(cv::Mat points, int equation) { cv::Mat labels(points.rows, 1, CV_32FC1); for (int i = 0; i < points.rows; i++) { float x = points.at<float>(i, 0); float y = points.at<float>(i, 1); labels.at<float>(i, 0) = f(x, y, equation); } return labels; } void svm(cv::Mat& trainingData, cv::Mat& trainingClasses, cv::Mat& testData, cv::Mat& testClasses) { Mat traning_label(trainingClasses.rows, 1, CV_32SC1); for (int i = 0; i < trainingClasses.rows; i++){ traning_label.at<int>(i, 0) = trainingClasses.at<float>(i, 0); } cv::Ptr<cv::ml::SVM> svm = ml::SVM::create(); svm->setType(ml::SVM::Types::C_SVC); svm->setKernel(ml::SVM::KernelTypes::RBF); //svm->setDegree(0); // for poly svm->setGamma(20); // for poly/rbf/sigmoid //svm->setCoef0(0); // for poly/sigmoid svm->setC(7); // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR //svm->setNu(0); // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR //svm->setP(0); // for CV_SVM_EPS_SVR svm->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 1000, 1E-6)); svm->train(trainingData, ml::SampleTypes::ROW_SAMPLE, traning_label); cv::Mat predicted(testClasses.rows, 1, CV_32F); svm->predict(testData, predicted); cout << "Accuracy_{SVM} = " << evaluate(predicted, testClasses) << endl; plot_binary(testData, predicted, "Predictions SVM"); // plot support vectors if (plotSupportVectors) { cv::Mat plot_sv(size, size, CV_8UC3); plot_sv.setTo(cv::Scalar(255.0, 255.0, 255.0)); Mat support_vectors = svm->getSupportVectors(); for (int vecNum = 0; vecNum < support_vectors.rows; vecNum++){ cv::circle(plot_sv, Point(support_vectors.row(vecNum).at<float>(0)*size, support_vectors.row(vecNum).at<float>(1)*size), 3, CV_RGB(0, 0, 0)); } namedWindow("Support Vectors", CV_WINDOW_KEEPRATIO); cv::imshow("Support Vectors", plot_sv); } } int main(){ cv::Mat trainingData(numTrainingPoints, 2, CV_32FC1); cv::Mat testData(numTestPoints, 2, CV_32FC1); cv::randu(trainingData, 0, 1); 
cv::randu(testData, 0, 1); cv::Mat trainingClasses = labelData(trainingData, eq); cv::Mat testClasses = labelData(testData, eq); plot_binary(trainingData, trainingClasses, "Training Data"); plot_binary(testData, testClasses, "Test Data"); svm(trainingData, trainingClasses, testData, testClasses); waitKey(0); return 0; }

  • 相关阅读:
    curd_4
    curd_2
    Python Regex库的使用
    Python Assert 确认条件为真的工具
    Python Regex库的使用(2)
    Python lambda的用法
    Python 列表综合
    with ss(date,date2) (select * from sysdummy1) select * from ss
    延迟执行函数
    ObjectiveC 的基本数据类型、数字、字符串和集合等介绍
  • 原文地址:https://www.cnblogs.com/yfceshi/p/7044264.html
Copyright © 2011-2022 走看看