zoukankan      html  css  js  c++  java
  • SVM的使用train()

    注意:数据结构的一致性,在高维度数据一般使用rbf核函数,使用网格搜索思想迭代求出gamma和c。

    每行为一个样本,数据类型都围绕标黄代码而定义的。

    SVM训练如下坐标(左边一列为A类,右边为B类),然后预测给出的坐标属于哪一类。

    #include<opencv2opencv.hpp>
    #include<iostream>
    #include<opencv2ml.hpp> //引入机器学习
    using namespace cv;
    using namespace std;
    using namespace ml;
    
    int main()
    {
        //*1、类别标签labelsMat,因为其是短整型,所以labels定义成int类型。最后再转回char
        int labels[14] = { 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B', 'B', 'B' };
        Mat labelsMat(14, 1, CV_32S);//短整型
        for (int i = 0; i < labelsMat.rows; i++)
        {
            labelsMat.at<int>(i, 0) = labels[i];
        }
        //*2、用于训练的样本集trainingDataMat
        int trainingData[14][2] = { { 110, 204 }, { 105, 306 }, { 102, 410 }, { 99, 511 }, { 93, 610 }, { 89, 713 }, { 89, 817 },
        { 173, 208 }, { 175, 313 }, { 167, 415 }, { 163, 514 }, { 160, 612 }, { 156, 716 }, { 152, 819 } };
        Mat trainingDataMat(14, 2, CV_32F); //float类型
        for (int i = 0; i < trainingDataMat.rows; i++)
        {
            for (int j = 0; j < trainingDataMat.cols; j++)
            {
                trainingDataMat.at<float>(i, j) = trainingData[i][j];
            }
        }
        //*3、初始化SVM,参数参考 https://blog.csdn.net/qq_27278957/article/details/88736516
        Ptr<ml::SVM> svm = ml::SVM::create();
        svm->setType(SVM::C_SVC); //svm的类型,
        svm->setKernel(SVM::LINEAR); //核函数
        svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, FLT_EPSILON)); //终止条件
        //*4、训练模型
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);//训练样本的数据类型必须是CV_32F,标签可以是CV_32S或其他。
        svm->train(tData);
        svm->save("svmData.xml");
        //*5、预测
        Mat tmp(1, 2, CV_32F);
        tmp.at<float>(0, 0) = 163;
        tmp.at<float>(0, 1) = 600;
    
        char label = (char)svm->predict(tmp); //ASCII码转字符,预测结果为B
        cout << label << endl;
    
        waitKey(0);
        return 0;
    }

    上图绘制代码:

    Mat plot(900, 900, CV_8U);
    vector<Point> myPoint(14);//14个点
    for (int i = 0; i < myPoint.size(); i++)
    {
        myPoint[i].x = trainingData[i][0];
        myPoint[i].y = trainingData[i][1];
        circle(plot, myPoint[i], 15, Scalar(255), -1);
    }
    namedWindow("坐标点", 0);
    imshow("坐标点", plot);

     【参考】

    https://blog.csdn.net/bigFatCat_Tom/article/details/95201903?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

  • 相关阅读:
    Python操作RabbitMQ
    数组的排序算法
    元类
    Python 中的单例模式
    JS的Ajax和同源策略
    Ajax
    Linux目录结构以及文件操作
    Pymysql
    struts2拦截器和过滤器区别
    为Github 托管项目的访问添加SSH keys
  • 原文地址:https://www.cnblogs.com/xixixing/p/12376324.html
Copyright © 2011-2022 走看看