zoukankan      html  css  js  c++  java
  • 在opencv3中利用SVM进行图像目标检测和分类

    采用鼠标事件,手动选择样本点,包括目标样本和背景样本。组成训练数据进行训练

    1、主函数

    #include "stdafx.h"
    #include "opencv2/opencv.hpp"
    using namespace cv;
    using namespace cv::ml;
    
    Mat img,image;
    Mat targetData, backData;
    bool flag = true;
    string wdname = "image";
    
    void on_mouse(int event, int x, int y, int flags, void* ustc); //鼠标取样本点
    void getTrainData(Mat &train_data, Mat &train_label);  //生成训练数据 
    void svm(); //svm分类
    
    
    int main(int argc, char** argv)
    {
        string path = "d:/peppers.png";
        img = imread(path);
        img.copyTo(image);
        if (img.empty())
        {
            cout << "Image load error";
            return 0;
        }
        namedWindow(wdname);
        setMouseCallback(wdname, on_mouse, 0);
    
        for (;;)
        {
            imshow("image", img);
    
            int c = waitKey(0);
            if ((c & 255) == 27)
            {
                cout << "Exiting ...
    ";
                break;
            }
            if ((char)c == 'c')
            {
                flag = false;
            }
            if ((char)c == 'q')
            {
                destroyAllWindows();
                break;
            }
        }
        svm();
        return 0;
    }

    首先输入图像,调用setMouseCallback函数进行鼠标取点

    2、鼠标事件

    //鼠标在图像上取样本点,按q键退出
    void on_mouse(int event, int x, int y, int flags, void* ustc)
    {
        if (event == CV_EVENT_LBUTTONDOWN)
        {
            Point pt = Point(x, y);
            Vec3b point = img.at<Vec3b>(y, x);  //取出该坐标处的像素值,注意x,y的顺序
            Mat tmp = (Mat_<float>(1, 3) << point[0], point[1], point[2]);
            if (flag)
            {
                targetData.push_back(tmp); //加入正样本矩阵
                circle(img, pt, 2, Scalar(0, 255, 255), -1, 8); //画圆,在图上显示点击的点 
    
            }
    
            else
            {
                backData.push_back(tmp); //加入负样本矩阵
                circle(img, pt, 2, Scalar(255, 0, 0), -1, 8); 
    
            }
            imshow(wdname, img);
        }
    }

    用鼠标在图像上点击,取出当前点的红绿蓝像素值进行训练。先选择任意个目标样本,然后按"c“键后选择任意个背景样本。样本数可以自己随意决定。样本选择完后,按”q"键完成样本选择。

    3、svm分类

    void getTrainData(Mat &train_data, Mat &train_label)
    {
        int m = targetData.rows;
        int n = backData.rows;
        cout << "正样本数::" << m << endl;
        cout << "负样本数:" << n << endl;
        vconcat(targetData, backData, train_data); //合并所有的样本点,作为训练数据
        train_label = Mat(m + n, 1, CV_32S, Scalar::all(1)); //初始化标注
        for (int i = m; i < m + n; i++)
            train_label.at<int>(i, 0) = -1;
    }
    
    void svm()
    {
        Mat train_data, train_label;
        getTrainData(train_data, train_label); //获取鼠标选择的样本训练数据
    
        // 设置参数
        Ptr<SVM> svm = SVM::create();
        svm->setType(SVM::C_SVC);
        svm->setKernel(SVM::LINEAR);
    
        // 训练分类器
        Ptr<TrainData> tData = TrainData::create(train_data, ROW_SAMPLE, train_label);
        svm->train(tData);
    
        Vec3b color(0, 0, 0);
        // Show the decision regions given by the SVM
        for (int i = 0; i < image.rows; ++i)
        for (int j = 0; j < image.cols; ++j)
        {
            Vec3b point = img.at<Vec3b>(i, j);  //取出该坐标处的像素值
            Mat sampleMat = (Mat_<float>(1, 3) << point[0], point[1], point[2]);
            float response = svm->predict(sampleMat);  //进行预测,返回1或-1,返回类型为float
            if ((int)response != 1)
                image.at<Vec3b>(i, j) = color;  //将背景点设为黑色
        }
    
        imshow("SVM Simple Example", image); // show it to the user
        waitKey(0);
    }

    将正负样本矩阵,用vconcat合并成一个矩阵,用作训练分类器,并对相应的样本进行标注。最后将识别出的目标保留,将背景部分调成黑色。

    4、完整程序

    // svm.cpp : 定义控制台应用程序的入口点。
    //
    
    #include "stdafx.h"
    #include "opencv2/opencv.hpp"
    using namespace cv;
    using namespace cv::ml;
    
    Mat img,image;
    Mat targetData, backData;
    bool flag = true;
    string wdname = "image";
    
    void on_mouse(int event, int x, int y, int flags, void* ustc); //鼠标取样本点
    void getTrainData(Mat &train_data, Mat &train_label);  //生成训练数据 
    void svm(); //svm分类
    
    
    int main(int argc, char** argv)
    {
        string path = "d:/peppers.png";
        img = imread(path);
        img.copyTo(image);
        if (img.empty())
        {
            cout << "Image load error";
            return 0;
        }
        namedWindow(wdname);
        setMouseCallback(wdname, on_mouse, 0);
    
        for (;;)
        {
            imshow("image", img);
    
            int c = waitKey(0);
            if ((c & 255) == 27)
            {
                cout << "Exiting ...
    ";
                break;
            }
            if ((char)c == 'c')
            {
                flag = false;
            }
            if ((char)c == 'q')
            {
                destroyAllWindows();
                break;
            }
        }
        svm();
        return 0;
    }
    
    //鼠标在图像上取样本点,按q键退出
    void on_mouse(int event, int x, int y, int flags, void* ustc)
    {
        if (event == CV_EVENT_LBUTTONDOWN)
        {
            Point pt = Point(x, y);
            Vec3b point = img.at<Vec3b>(y, x);  //取出该坐标处的像素值,注意x,y的顺序
            Mat tmp = (Mat_<float>(1, 3) << point[0], point[1], point[2]);
            if (flag)
            {
                targetData.push_back(tmp); //加入正样本矩阵
                circle(img, pt, 2, Scalar(0, 255, 255), -1, 8); //画出点击的点 
    
            }
    
            else
            {
                backData.push_back(tmp); //加入负样本矩阵
                circle(img, pt, 2, Scalar(255, 0, 0), -1, 8); 
    
            }
            imshow(wdname, img);
        }
    }
    
    void getTrainData(Mat &train_data, Mat &train_label)
    {
        int m = targetData.rows;
        int n = backData.rows;
        cout << "正样本数::" << m << endl;
        cout << "负样本数:" << n << endl;
        vconcat(targetData, backData, train_data); //合并所有的样本点,作为训练数据
        train_label = Mat(m + n, 1, CV_32S, Scalar::all(1)); //初始化标注
        for (int i = m; i < m + n; i++)
            train_label.at<int>(i, 0) = -1;
    }
    
    void svm()
    {
        Mat train_data, train_label;
        getTrainData(train_data, train_label); //获取鼠标选择的样本训练数据
    
        // 设置参数
        Ptr<SVM> svm = SVM::create();
        svm->setType(SVM::C_SVC);
        svm->setKernel(SVM::LINEAR);
    
        // 训练分类器
        Ptr<TrainData> tData = TrainData::create(train_data, ROW_SAMPLE, train_label);
        svm->train(tData);
    
        Vec3b color(0, 0, 0);
        // Show the decision regions given by the SVM
        for (int i = 0; i < image.rows; ++i)
        for (int j = 0; j < image.cols; ++j)
        {
            Vec3b point = img.at<Vec3b>(i, j);  //取出该坐标处的像素值
            Mat sampleMat = (Mat_<float>(1, 3) << point[0], point[1], point[2]);
            float response = svm->predict(sampleMat);  //进行预测,返回1或-1,返回类型为float
            if ((int)response != 1)
                image.at<Vec3b>(i, j) = color;  //将背景设置为黑色
        }
    
        imshow("SVM Simple Example", image); 
        waitKey(0);
    }

    输入原图像:

    程序运行后显示:

  • 相关阅读:
    FileUpload的使用
    关于hibernate4的配置我要好好反省一下
    比较SQL Server 2000 数据库中两个库的差异
    用google生活
    用OWC11图形分析本页面及其他页面Table中的数据
    请教ASP.NET培训应该培训的内容和以及顺序
    最近一个快要结束的项目的BUG分析
    我也发软件开发团队的思考(侧重点是人员)
    一个SQL语句的问题,我百思不得其解,请教各位
    分享C#高端视频教程
  • 原文地址:https://www.cnblogs.com/denny402/p/5020551.html
Copyright © 2011-2022 走看看