zoukankan      html  css  js  c++  java
  • SVM《一、硬分类、目标函数推导、二分类实验》

     1 #include <opencv2/core.hpp>
     2 #include <opencv2/imgproc.hpp>
     3 #include "opencv2/imgcodecs.hpp"
     4 #include <opencv2/highgui.hpp>
     5 #include <opencv2/ml.hpp>
     6 
     7 using namespace cv;
     8 using namespace cv::ml;
     9 
    10 #include<iostream>
    11 using namespace std;
    12 
    13 
    14 void PrintMat(const Mat& mat)
    15 {
    16     for (int j = 0; j < mat.rows; j++)
    17     {
    18         for (int i = 0; i < mat.cols; i++)
    19         {
    20             cout << mat.ptr<float>(j)[i] << " ";
    21         }
    22         cout << endl;
    23     }
    24     cout << endl;
    25 }
    26 
    27 int main(int, char**)
    28 {
    29     // Data for visual representation
    30     int width = 512, height = 512;
    31     Mat image = Mat::zeros(height, width, CV_8UC3);
    32 
    33     //<1>训练数据
    34     int labels[3] = { 1, 1, -1 };
    35     float trainingData[3][2] = { { 30, 30 },{ 40, 40 },{10, 10 }};
    36     Mat trainingDataMat(3, 2, CV_32FC1, trainingData);
    37     Mat labelsMat(3, 1, CV_32SC1, labels);
    38 
    39     //SVM初始化并开始训练
    40     Ptr<SVM> svm = SVM::create();
    41     svm->setType(SVM::C_SVC);
    42     svm->setKernel(SVM::LINEAR);
    43     svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
    44     svm->train(trainingDataMat, ROW_SAMPLE, labelsMat);
    45 
    46     //预测 验证样本
    47     Vec3b green(0, 255, 0), blue(255, 0, 0);
    48     for (int i = 0; i < image.rows; ++i)
    49         for (int j = 0; j < image.cols; ++j)
    50         {
    51             Mat sampleMat = (Mat_<float>(1, 2) << j, i);
    52             float response = svm->predict(sampleMat);
    53 
    54             if (response == 1)
    55                 image.at<Vec3b>(i, j) = green;
    56             else if (response == -1)
    57                 image.at<Vec3b>(i, j) = blue;
    58         }
    59 
    60     // 显示训练数据
    61     int thickness = -1;
    62     int lineType = 8;
    63     circle(image, Point(30, 30), 5, Scalar(0, 0, 0), thickness, lineType);
    64     circle(image, Point(40, 40), 5, Scalar(0, 0, 0), thickness, lineType);
    65     circle(image, Point(10, 10), 5, Scalar(255, 255, 255), thickness, lineType);
    66 
    67     //显示支撑向量
    68     thickness = 2;
    69     lineType = 8;
    70 
    71     Mat sv = svm->getUncompressedSupportVectors();
    72     PrintMat(sv);
    73     Mat sv_ = svm->getSupportVectors();
    74     PrintMat(sv_);
    75 
    76     for (int i = 0; i < sv.rows; ++i)
    77     {
    78         const float* v = sv.ptr<float>(i);
    79         circle(image, Point((int)v[0], (int)v[1]), 6, Scalar(0, 0, 255), thickness, lineType);
    80     }
    81     // imwrite("result.png", image);        // save the image
    82     imshow("SVM Simple Example", image); // show it to the user
    83     waitKey(0);
    84 
    85 }

    如图,白色点是正样本、黑色点是负样本。红色点表示支撑向量上的点。绿色和蓝色是svm->predict的结果。

    如图:Mat sv = svm->getUncompressedSupportVectors(); 得出支撑向量上的样本点,打印出来就是(30 ,30) (10, 10)

    Mat sv_ = svm->getSupportVectors(); 得出决策面方程,和我们上述计算的一致(这里点都放大了10倍),不知为什么没有超平面的截距

     参考:https://blog.csdn.net/chaipp0607/article/details/73662441

  • 相关阅读:
    graphite custom functions
    falcon适配ldap密码同步
    dell 远程管理卡的使用racadm
    mac 入门
    使用 kafkat 在线扩缩容 kafka replicas
    python收集jvm数据
    kafka java.rmi.server.ExportException: Port already in use
    centos6安装最新syslog-ng推送hdfs
    从 falcon api 中获取数据
    fluentd 推送 mariadb audit log
  • 原文地址:https://www.cnblogs.com/winslam/p/10183124.html
Copyright © 2011-2022 走看看