1 #include <opencv2/core.hpp> 2 #include <opencv2/imgproc.hpp> 3 #include "opencv2/imgcodecs.hpp" 4 #include <opencv2/highgui.hpp> 5 #include <opencv2/ml.hpp> 6 7 using namespace cv; 8 using namespace cv::ml; 9 10 #include<iostream> 11 using namespace std; 12 13 14 void PrintMat(const Mat& mat) 15 { 16 for (int j = 0; j < mat.rows; j++) 17 { 18 for (int i = 0; i < mat.cols; i++) 19 { 20 cout << mat.ptr<float>(j)[i] << " "; 21 } 22 cout << endl; 23 } 24 cout << endl; 25 } 26 27 int main(int, char**) 28 { 29 // Data for visual representation 30 int width = 512, height = 512; 31 Mat image = Mat::zeros(height, width, CV_8UC3); 32 33 //<1>训练数据 34 int labels[3] = { 1, 1, -1 }; 35 float trainingData[3][2] = { { 30, 30 },{ 40, 40 },{10, 10 }}; 36 Mat trainingDataMat(3, 2, CV_32FC1, trainingData); 37 Mat labelsMat(3, 1, CV_32SC1, labels); 38 39 //SVM初始化并开始训练 40 Ptr<SVM> svm = SVM::create(); 41 svm->setType(SVM::C_SVC); 42 svm->setKernel(SVM::LINEAR); 43 svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6)); 44 svm->train(trainingDataMat, ROW_SAMPLE, labelsMat); 45 46 //预测 验证样本 47 Vec3b green(0, 255, 0), blue(255, 0, 0); 48 for (int i = 0; i < image.rows; ++i) 49 for (int j = 0; j < image.cols; ++j) 50 { 51 Mat sampleMat = (Mat_<float>(1, 2) << j, i); 52 float response = svm->predict(sampleMat); 53 54 if (response == 1) 55 image.at<Vec3b>(i, j) = green; 56 else if (response == -1) 57 image.at<Vec3b>(i, j) = blue; 58 } 59 60 // 显示训练数据 61 int thickness = -1; 62 int lineType = 8; 63 circle(image, Point(30, 30), 5, Scalar(0, 0, 0), thickness, lineType); 64 circle(image, Point(40, 40), 5, Scalar(0, 0, 0), thickness, lineType); 65 circle(image, Point(10, 10), 5, Scalar(255, 255, 255), thickness, lineType); 66 67 //显示支撑向量 68 thickness = 2; 69 lineType = 8; 70 71 Mat sv = svm->getUncompressedSupportVectors(); 72 PrintMat(sv); 73 Mat sv_ = svm->getSupportVectors(); 74 PrintMat(sv_); 75 76 for (int i = 0; i < sv.rows; ++i) 77 { 78 const float* v = sv.ptr<float>(i); 79 
circle(image, Point((int)v[0], (int)v[1]), 6, Scalar(0, 0, 255), thickness, lineType); 80 } 81 // imwrite("result.png", image); // save the image 82 imshow("SVM Simple Example", image); // show it to the user 83 waitKey(0); 84 85 }
如图,黑色实心点是正样本(标签 1),白色实心点是负样本(标签 -1)。红色圆圈标出的是支撑向量。绿色和蓝色区域是 svm->predict 对每个像素的预测结果:绿色表示预测为 1,蓝色表示预测为 -1。
如图:Mat sv = svm->getUncompressedSupportVectors(); 返回未压缩的支撑向量,即落在间隔边界上的训练样本,打印出来是 (30, 30) 和 (10, 10)。
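Mat sv_ = svm->getSupportVectors(); 对线性核返回的是由所有支撑向量压缩得到的单个向量(预测时实际使用的就是它),相当于决策面的法向量 $w$,和我们上述计算的一致(这里点都放大了 10 倍)。超平面的截距并不存储在支撑向量中,需要通过 svm->getDecisionFunction(0, alpha, svidx) 获取:该函数返回 $\rho$,决策函数为 $f(x) = \sum_i \alpha_i K(sv_i, x) - \rho$,因此截距为 $-\rho$。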
参考:https://blog.csdn.net/chaipp0607/article/details/73662441