注意:数据结构的一致性,在高维度数据一般使用rbf核函数,使用网格搜索思想迭代求出gamma和c。
每行为一个样本,数据类型都围绕标黄代码而定义的。
SVM训练如下坐标(左边一列为A类,右边为B类),然后预测给出的坐标属于哪一类。
#include<opencv2opencv.hpp> #include<iostream> #include<opencv2ml.hpp> //引入机器学习 using namespace cv; using namespace std; using namespace ml; int main() { //*1、类别标签labelsMat,因为其是短整型,所以labels定义成int类型。最后再转回char int labels[14] = { 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B', 'B', 'B' }; Mat labelsMat(14, 1, CV_32S);//短整型 for (int i = 0; i < labelsMat.rows; i++) { labelsMat.at<int>(i, 0) = labels[i]; } //*2、用于训练的样本集trainingDataMat int trainingData[14][2] = { { 110, 204 }, { 105, 306 }, { 102, 410 }, { 99, 511 }, { 93, 610 }, { 89, 713 }, { 89, 817 }, { 173, 208 }, { 175, 313 }, { 167, 415 }, { 163, 514 }, { 160, 612 }, { 156, 716 }, { 152, 819 } }; Mat trainingDataMat(14, 2, CV_32F); //float类型 for (int i = 0; i < trainingDataMat.rows; i++) { for (int j = 0; j < trainingDataMat.cols; j++) { trainingDataMat.at<float>(i, j) = trainingData[i][j]; } } //*3、初始化SVM,参数参考 https://blog.csdn.net/qq_27278957/article/details/88736516 Ptr<ml::SVM> svm = ml::SVM::create(); svm->setType(SVM::C_SVC); //svm的类型, svm->setKernel(SVM::LINEAR); //核函数 svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, FLT_EPSILON)); //终止条件 //*4、训练模型 Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);//训练样本的数据类型必须是CV_32F,标签可以是CV_32S或其他。 svm->train(tData); svm->save("svmData.xml"); //*5、预测 Mat tmp(1, 2, CV_32F); tmp.at<float>(0, 0) = 163; tmp.at<float>(0, 1) = 600; char label = (char)svm->predict(tmp); //ASCII码转字符,预测结果为B cout << label << endl; waitKey(0); return 0; }
上图绘制代码:
Mat plot(900, 900, CV_8U); vector<Point> myPoint(14);//14个点 for (int i = 0; i < myPoint.size(); i++) { myPoint[i].x = trainingData[i][0]; myPoint[i].y = trainingData[i][1]; circle(plot, myPoint[i], 15, Scalar(255), -1); } namedWindow("坐标点", 0); imshow("坐标点", plot);
【参考】