zoukankan      html  css  js  c++  java
  • OpenCV官方例程引导与赏析-2

      在安装好的OpenCV文件夹下可以找到这些官方示例。具体位置取决于各人的安装路径,大体如下:***\opencv\sources\samples\cpp。

      如“彩色目标跟踪”:Camshift

      “光流”:optical flow

      “点跟踪”:lkdemo

      “人脸检测”:objectDetection

      “支持向量机(随机梯度下降)训练”:train_svmsgd(cv::ml::SVMSGD)

      

      在上一级目录中还可以发现,除了C++之外,还有其他语言的示例,如Java、Python等。

      

     2.1  camshiftdemo

    #include "opencv2/core/utility.hpp"
    #include "opencv2/video/tracking.hpp"
    #include "opencv2/imgproc.hpp"
    #include "opencv2/videoio.hpp"
    #include "opencv2/highgui.hpp"
    
    #include <iostream>
    #include <ctype.h>
    
    using namespace cv;
    using namespace std;
    
    // Current frame being displayed; onMouse() clips the selection rectangle to its size.
    Mat image;
    
    // Show the back-projection instead of the camera image (toggled with 'b').
    bool backprojMode = false;
    // True while the left mouse button is held and a selection rectangle is being dragged.
    bool selectObject = false;
    // Tracking state: 0 = not tracking, -1 = new selection that main() must set up, 1 = tracking.
    int trackObject = 0;
    // Whether the "Histogram" window is displayed (toggled with 'h').
    bool showHist = true;
    // Mouse position where the current selection drag started.
    Point origin;
    // User-selected region (image coordinates) used to build the hue histogram.
    Rect selection;
    // Trackbar-controlled HSV mask thresholds: value in [min(vmin,vmax), max(vmin,vmax)], saturation >= smin.
    int vmin = 10, vmax = 256, smin = 30;
    
    // Mouse callback for the "CamShift Demo" window.
    // Dragging with the left button defines `selection` (clipped to the image);
    // releasing the button over a non-empty area sets trackObject = -1 so that
    // main() (re)initialises CAMShift from the new selection.
    static void onMouse(int event, int x, int y, int, void*)
    {
        if (selectObject)
        {
            // Normalise the drag so any drag direction works, then clip to the image.
            const int left = MIN(x, origin.x);
            const int top = MIN(y, origin.y);
            selection = Rect(left, top, std::abs(x - origin.x), std::abs(y - origin.y))
                        & Rect(0, 0, image.cols, image.rows);
        }

        if (event == EVENT_LBUTTONDOWN)
        {
            origin = Point(x, y);
            selection = Rect(x, y, 0, 0);
            selectObject = true;
        }
        else if (event == EVENT_LBUTTONUP)
        {
            selectObject = false;
            const bool nonEmpty = selection.width > 0 && selection.height > 0;
            if (nonEmpty)
                trackObject = -1;   // CAMShift properties are set up in the main() loop
        }
    }
    
    // Help text listing the keyboard shortcuts handled in main().
    // Adjacent literals are concatenated by the compiler; each line ends in an
    // explicit "\n" escape (a raw line break inside a literal does not compile).
    std::string hot_keys =
        "\n\nHot keys: \n"
        "\tESC - quit the program\n"
        "\tc - stop the tracking\n"
        "\tb - switch to/from backprojection view\n"
        "\th - show/hide object histogram\n"
        "\tp - pause video\n"
        "To initialize tracking, select the object with mouse\n";
    
    static void help(const char** argv)
    {
        cout << "
    This is a demo that shows mean-shift based tracking
    "
            "You select a color objects such as your face and it tracks it.
    "
            "This reads from video camera (0 by default, or the camera number the user enters
    "
            "Usage: 
    	";
        cout << argv[0] << " [camera number]
    ";
        cout << hot_keys;
    }
    
    // Command-line specification for cv::CommandLineParser: -h/--help prints the
    // usage; the optional positional argument is the camera index (default 0).
    const char* keys = "{help h | | show help message}{@camera_number| 0 | camera number}";
    
    int main(int argc, const char** argv)
    {
        VideoCapture cap;
        Rect trackWindow;
        int hsize = 16;
        float hranges[] = { 0,180 };
        const float* phranges = hranges;
        CommandLineParser parser(argc, argv, keys);
        if (parser.has("help"))
        {
            help(argv);
            return 0;
        }
        int camNum = parser.get<int>(0);
        cap.open(camNum);
    
        if (!cap.isOpened())
        {
            help(argv);
            cout << "***Could not initialize capturing...***
    ";
            cout << "Current parameter's value: 
    ";
            parser.printMessage();
            return -1;
        }
        cout << hot_keys;
        namedWindow("Histogram", 0);
        namedWindow("CamShift Demo", 0);
        setMouseCallback("CamShift Demo", onMouse, 0);
        createTrackbar("Vmin", "CamShift Demo", &vmin, 256, 0);
        createTrackbar("Vmax", "CamShift Demo", &vmax, 256, 0);
        createTrackbar("Smin", "CamShift Demo", &smin, 256, 0);
    
        Mat frame, hsv, hue, mask, hist, histimg = Mat::zeros(200, 320, CV_8UC3), backproj;
        bool paused = false;
    
        for (;;)
        {
            if (!paused)
            {
                cap >> frame;
                if (frame.empty())
                    break;
            }
    
            frame.copyTo(image);
    
            if (!paused)
            {
                cvtColor(image, hsv, COLOR_BGR2HSV);
    
                if (trackObject)
                {
                    int _vmin = vmin, _vmax = vmax;
    
                    inRange(hsv, Scalar(0, smin, MIN(_vmin, _vmax)),
                        Scalar(180, 256, MAX(_vmin, _vmax)), mask);
                    int ch[] = { 0, 0 };
                    hue.create(hsv.size(), hsv.depth());
                    mixChannels(&hsv, 1, &hue, 1, ch, 1);
    
                    if (trackObject < 0)
                    {
                        // Object has been selected by user, set up CAMShift search properties once
                        Mat roi(hue, selection), maskroi(mask, selection);
                        calcHist(&roi, 1, 0, maskroi, hist, 1, &hsize, &phranges);
                        normalize(hist, hist, 0, 255, NORM_MINMAX);
    
                        trackWindow = selection;
                        trackObject = 1; // Don't set up again, unless user selects new ROI
    
                        histimg = Scalar::all(0);
                        int binW = histimg.cols / hsize;
                        Mat buf(1, hsize, CV_8UC3);
                        for (int i = 0; i < hsize; i++)
                            buf.at<Vec3b>(i) = Vec3b(saturate_cast<uchar>(i * 180. / hsize), 255, 255);
                        cvtColor(buf, buf, COLOR_HSV2BGR);
    
                        for (int i = 0; i < hsize; i++)
                        {
                            int val = saturate_cast<int>(hist.at<float>(i) * histimg.rows / 255);
                            rectangle(histimg, Point(i * binW, histimg.rows),
                                Point((i + 1) * binW, histimg.rows - val),
                                Scalar(buf.at<Vec3b>(i)), -1, 8);
                        }
                    }
    
                    // Perform CAMShift
                    calcBackProject(&hue, 1, 0, hist, backproj, &phranges);
                    backproj &= mask;
                    RotatedRect trackBox = CamShift(backproj, trackWindow,
                        TermCriteria(TermCriteria::EPS | TermCriteria::COUNT, 10, 1));
                    if (trackWindow.area() <= 1)
                    {
                        int cols = backproj.cols, rows = backproj.rows, r = (MIN(cols, rows) + 5) / 6;
                        trackWindow = Rect(trackWindow.x - r, trackWindow.y - r,
                            trackWindow.x + r, trackWindow.y + r) &
                            Rect(0, 0, cols, rows);
                    }
    
                    if (backprojMode)
                        cvtColor(backproj, image, COLOR_GRAY2BGR);
                    ellipse(image, trackBox, Scalar(0, 0, 255), 3, LINE_AA);
                }
            }
            else if (trackObject < 0)
                paused = false;
    
            if (selectObject && selection.width > 0 && selection.height > 0)
            {
                Mat roi(image, selection);
                bitwise_not(roi, roi);
            }
    
            imshow("CamShift Demo", image);
            imshow("Histogram", histimg);
    
            char c = (char)waitKey(10);
            if (c == 27)
                break;
            switch (c)
            {
            case 'b':
                backprojMode = !backprojMode;
                break;
            case 'c':
                trackObject = 0;
                histimg = Scalar::all(0);
                break;
            case 'h':
                showHist = !showHist;
                if (!showHist)
                    destroyWindow("Histogram");
                else
                    namedWindow("Histogram", 1);
                break;
            case 'p':
                paused = !paused;
                break;
            default:
                ;
            }
        }
    
        return 0;
    }
    camshift

    运行示例:

     2.2  opticalFlow

    //---------------------------------【头文件、命名空间包含部分】----------------------------
    //        描述:包含程序所使用的头文件和命名空间
    //-------------------------------------------------------------------------------------------------
    #include <opencv2/video/video.hpp>
    #include <opencv2/highgui/highgui.hpp>
    #include <opencv2/imgproc/imgproc.hpp>
    #include <opencv2/core/core.hpp>
    #include <iostream>
    #include <cstdio>
    
    using namespace std;
    using namespace cv;
    
    
    
    
    
    //-----------------------------------[Global function declarations]-------------------------------
    //        Description: forward declarations of the helpers defined after main()
    //-------------------------------------------------------------------------------------------------
    void tracking(Mat& frame, Mat& output);   // processes one frame and shows the tracking result
    bool addNewPoints();                       // whether new corners should be detected this frame
    bool acceptTrackedPoint(int i);            // whether tracked point i is kept
    
    //-----------------------------------[Global variables]--------------------------------------------
    //        Description: state shared by tracking() and its helpers from one frame to the next
    //-------------------------------------------------------------------------------------------------
    string window_name = "optical flow tracking";
    Mat gray;    // current frame converted to grayscale
    Mat gray_prev;    // previous grayscale frame (the reference image for optical flow)
    vector<Point2f> points[2];    // points[0]: feature positions in the previous frame; points[1]: their new positions
    vector<Point2f> initial;    // position where each tracked point was first detected (start of its drawn trajectory)
    vector<Point2f> features;    // corners returned by goodFeaturesToTrack()
    int maxCount = 500;    // maximum number of corners to detect
    double qLevel = 0.01;    // qualityLevel argument of goodFeaturesToTrack()
    double minDist = 10.0;    // minimum distance between two detected corners
    vector<uchar> status;    // per-point status from calcOpticalFlowPyrLK: 1 if the flow was found, 0 otherwise
    vector<float> err;    // per-point error output of calcOpticalFlowPyrLK (not otherwise used here)
    
    
    //-----------------------------------[main()]--------------------------------------------------------
    //        Description: console entry point. Reads "1.avi" frame by frame and runs the
    //        optical-flow tracker on each frame until the video ends or ESC is pressed.
    //        Returns -1 if the video cannot be opened, 0 otherwise.
    //-------------------------------------------------------------------------------------------------
    int main()
    {
        Mat frame;
        Mat result;

        VideoCapture capture("1.avi");

        // Report a missing or unreadable video instead of exiting silently with success.
        if (!capture.isOpened())
        {
            printf(" --(!) Could not open video file 1.avi\n");
            return -1;
        }

        while (true)
        {
            capture >> frame;

            if (frame.empty())
            {
                printf(" --(!) No captured frame -- Break!\n");
                break;
            }
            tracking(frame, result);

            // Wait ~50 ms between frames; ESC (27) quits early.
            int c = waitKey(50);
            if ((char)c == 27)
            {
                break;
            }
        }
        return 0;
    }
    
    //-------------------------------------------------------------------------------------------------
    // function: tracking
    // brief: Tracks feature points from the previous frame into `frame` with pyramidal
    //        Lucas-Kanade optical flow, keeps only the points acceptTrackedPoint() accepts,
    //        draws their trajectories and shows the result.
    // parameter: frame   input video frame (BGR)
    //            output  receives a copy of `frame` with the tracking result drawn on it
    // return: void
    //-------------------------------------------------------------------------------------------------
    void tracking(Mat& frame, Mat& output)
    {
        // OpenCV 3+ constant; OpenCV 2 spelled it CV_BGR2GRAY.
        cvtColor(frame, gray, COLOR_BGR2GRAY);
        frame.copyTo(output);

        // Top up the tracked set with freshly detected corners when too few remain.
        if (addNewPoints())
        {
            goodFeaturesToTrack(gray, features, maxCount, qLevel, minDist);
            points[0].insert(points[0].end(), features.begin(), features.end());
            initial.insert(initial.end(), features.begin(), features.end());
        }

        // On the very first frame there is no previous image yet, so use the current one.
        if (gray_prev.empty())
            gray.copyTo(gray_prev);

        // Lucas-Kanade motion estimation: where did each point of points[0] move to?
        calcOpticalFlowPyrLK(gray_prev, gray, points[0], points[1], status, err);

        // Compact points[1] and initial in place, keeping only the accepted points.
        size_t kept = 0;
        for (size_t idx = 0; idx < points[1].size(); ++idx)
        {
            if (!acceptTrackedPoint(idx))
                continue;
            initial[kept] = initial[idx];
            points[1][kept] = points[1][idx];
            ++kept;
        }
        points[1].resize(kept);
        initial.resize(kept);

        // Draw each trajectory (start -> current, red line) and the current position (filled green dot).
        for (size_t idx = 0; idx < kept; ++idx)
        {
            const Point2f& start = initial[idx];
            const Point2f& now = points[1][idx];
            line(output, start, now, Scalar(0, 0, 255));
            circle(output, now, 3, Scalar(0, 255, 0), -1);
        }

        // The current result becomes the reference for the next frame.
        swap(points[1], points[0]);
        swap(gray_prev, gray);

        imshow(window_name, output);
    }
    
    //-------------------------------------------------------------------------------------------------
    // function: addNewPoints
    // brief: 检测新点是否应该被添加
    // parameter:
    // return: 是否被添加标志
    //-------------------------------------------------------------------------------------------------
    bool addNewPoints()
    {
        return points[0].size() <= 10;
    }
    
    //-------------------------------------------------------------------------------------------------
    // function: acceptTrackedPoint
    // brief: 决定哪些跟踪点被接受
    // parameter:
    // return:
    //-------------------------------------------------------------------------------------------------
    bool acceptTrackedPoint(int i)
    {
        return status[i] && ((abs(points[0][i].x - points[1][i].x) + abs(points[0][i].y - points[1][i].y)) > 2);
    }
    opticalFlow

    运行示例

    2.3 lkdemo

    #include "opencv2/video/tracking.hpp"
    #include "opencv2/imgproc.hpp"
    #include "opencv2/videoio.hpp"
    #include "opencv2/highgui.hpp"
    
    #include <iostream>
    #include <ctype.h>
    
    using namespace cv;
    using namespace std;
    
    static void help()
    {
        // print a welcome message, and the OpenCV version
        cout << "
    This is a demo of Lukas-Kanade optical flow lkdemo(),
    "
            "Using OpenCV version " << CV_VERSION << endl;
        cout << "
    It uses camera by default, but you can provide a path to video as an argument.
    ";
        cout << "
    Hot keys: 
    "
            "	ESC - quit the program
    "
            "	r - auto-initialize tracking
    "
            "	c - delete all the points
    "
            "	n - switch the "night" mode on/off
    "
            "To add/remove a feature point click it
    " << endl;
    }
    
    // Last left-click position in the "LK Demo" window, and a request flag that main()
    // consumes: it removes a tracked point within 5 px of the click or adds a new one.
    Point2f point;
    bool addRemovePt = false;

    // Mouse callback: on a left-button press, remember the position and request an add/remove.
    static void onMouse(int event, int x, int y, int /*flags*/, void* /*param*/)
    {
        if (event != EVENT_LBUTTONDOWN)
            return;

        point.x = static_cast<float>(x);
        point.y = static_cast<float>(y);
        addRemovePt = true;
    }
    
    int main(int argc, char** argv)
    {
        VideoCapture cap;
        TermCriteria termcrit(TermCriteria::COUNT | TermCriteria::EPS, 20, 0.03);
        Size subPixWinSize(10, 10), winSize(31, 31);
    
        const int MAX_COUNT = 500;
        bool needToInit = false;
        bool nightMode = false;
    
        help();
        cv::CommandLineParser parser(argc, argv, "{@input|0|}");
        string input = parser.get<string>("@input");
    
        if (input.size() == 1 && isdigit(input[0]))
            cap.open(input[0] - '0');
        else
            cap.open(input);
    
        if (!cap.isOpened())
        {
            cout << "Could not initialize capturing...
    ";
            return 0;
        }
    
        namedWindow("LK Demo", 1);
        setMouseCallback("LK Demo", onMouse, 0);
    
        Mat gray, prevGray, image, frame;
        vector<Point2f> points[2];
    
        for (;;)
        {
            cap >> frame;
            if (frame.empty())
                break;
    
            frame.copyTo(image);
            cvtColor(image, gray, COLOR_BGR2GRAY);
    
            if (nightMode)
                image = Scalar::all(0);
    
            if (needToInit)
            {
                // automatic initialization
                goodFeaturesToTrack(gray, points[1], MAX_COUNT, 0.01, 10, Mat(), 3, 3, 0, 0.04);
                cornerSubPix(gray, points[1], subPixWinSize, Size(-1, -1), termcrit);
                addRemovePt = false;
            }
            else if (!points[0].empty())
            {
                vector<uchar> status;
                vector<float> err;
                if (prevGray.empty())
                    gray.copyTo(prevGray);
                calcOpticalFlowPyrLK(prevGray, gray, points[0], points[1], status, err, winSize,
                    3, termcrit, 0, 0.001);
                size_t i, k;
                for (i = k = 0; i < points[1].size(); i++)
                {
                    if (addRemovePt)
                    {
                        if (norm(point - points[1][i]) <= 5)
                        {
                            addRemovePt = false;
                            continue;
                        }
                    }
    
                    if (!status[i])
                        continue;
    
                    points[1][k++] = points[1][i];
                    circle(image, points[1][i], 3, Scalar(0, 255, 0), -1, 8);
                }
                points[1].resize(k);
            }
    
            if (addRemovePt && points[1].size() < (size_t)MAX_COUNT)
            {
                vector<Point2f> tmp;
                tmp.push_back(point);
                cornerSubPix(gray, tmp, winSize, Size(-1, -1), termcrit);
                points[1].push_back(tmp[0]);
                addRemovePt = false;
            }
    
            needToInit = false;
            imshow("LK Demo", image);
    
            char c = (char)waitKey(10);
            if (c == 27)
                break;
            switch (c)
            {
            case 'r':
                needToInit = true;
                break;
            case 'c':
                points[0].clear();
                points[1].clear();
                break;
            case 'n':
                nightMode = !nightMode;
                break;
            }
    
            std::swap(points[1], points[0]);
            cv::swap(prevGray, gray);
        }
    
        return 0;
    }
    lkdemo

    运行示例

     2.4 ObjectDetection

     //---------------------------------【头文件、命名空间包含部分】----------------------------
     //        描述:包含程序所使用的头文件和命名空间
     //-------------------------------------------------------------------------------------------------
    #include "opencv2/objdetect/objdetect.hpp"
    #include "opencv2/highgui/highgui.hpp"
    #include "opencv2/imgproc/imgproc.hpp"
    
    #include <iostream>
    #include <stdio.h>
    
    using namespace std;
    using namespace cv;
    
    
    
    
    // Detects faces (and eyes inside each face) in `frame`, draws them and shows the result.
    void detectAndDisplay(Mat frame);
    
    //--------------------------------[Global variables]--------------------------------------------------
    //        Description: cascade files, classifiers and display settings
    //-------------------------------------------------------------------------------------------------
    // Note: copy "haarcascade_frontalface_alt.xml" and "haarcascade_eye_tree_eyeglasses.xml" into the project directory; the paths below are relative, so they must be reachable from the program's working directory
    String face_cascade_name = "haarcascade_frontalface_alt.xml";
    String eyes_cascade_name = "haarcascade_eye_tree_eyeglasses.xml";
    CascadeClassifier face_cascade;
    CascadeClassifier eyes_cascade;
    string window_name = "Capture - Face detection";
    RNG rng(12345);    // not used by the code shown in this program
    
    
    //-----------------------------------【main( )函数】--------------------------------------------
    //        描述:控制台应用程序的入口函数,我们的程序从这里开始
    //-------------------------------------------------------------------------------------------------
    int main(void)
    {
        VideoCapture capture;
        Mat frame;
    
    
        //-- 1. 加载级联(cascades)
        if (!face_cascade.load(face_cascade_name)) { printf("--(!)Error loading
    "); return -1; };
        if (!eyes_cascade.load(eyes_cascade_name)) { printf("--(!)Error loading
    "); return -1; };
    
        //-- 2. 读取视频
        capture.open(0);
       
        if (capture.isOpened())
        {
            for (;;)
            {
                capture >> frame;
    
                //-- 3. 对当前帧使用分类器(Apply the classifier to the frame)
                if (!frame.empty())
                {
                    detectAndDisplay(frame);
                }
                else
                {
                    printf(" --(!) No captured frame -- Break!"); break;
                }
    
                int c = waitKey(10);
                if ((char)c == 'c') { break; }
    
            }
        }
        return 0;
    }
    
    
    // Detects faces in `frame`, then eyes inside each detected face; draws a magenta
    // ellipse per face and a blue circle per eye, and shows the annotated frame.
    // `frame` is a cv::Mat header passed by value, so drawing writes into the caller's pixels.
    void detectAndDisplay(Mat frame)
    {
        // Detection runs on a histogram-equalised grayscale copy of the frame.
        Mat frame_gray;
        cvtColor(frame, frame_gray, COLOR_BGR2GRAY);
        equalizeHist(frame_gray, frame_gray);

        //-- Face detection (OpenCV 2 used the flag CV_HAAR_SCALE_IMAGE instead of CASCADE_SCALE_IMAGE)
        std::vector<Rect> faces;
        face_cascade.detectMultiScale(frame_gray, faces, 1.1, 2, 0 | CASCADE_SCALE_IMAGE, Size(30, 30));

        for (const Rect& face : faces)
        {
            const Size halfAxes(face.width / 2, face.height / 2);
            const Point faceCenter(face.x + halfAxes.width, face.y + halfAxes.height);
            ellipse(frame, faceCenter, halfAxes, 0, 0, 360, Scalar(255, 0, 255), 2, 8, 0);

            //-- Eye detection inside the face region; eye rectangles are relative to that region.
            std::vector<Rect> eyes;
            Mat faceROI = frame_gray(face);
            eyes_cascade.detectMultiScale(faceROI, eyes, 1.1, 2, 0 | CASCADE_SCALE_IMAGE, Size(30, 30));

            for (const Rect& eye : eyes)
            {
                const Point eyeCenter(face.x + eye.x + eye.width / 2, face.y + eye.y + eye.height / 2);
                const int radius = cvRound((eye.width + eye.height) * 0.25);
                circle(frame, eyeCenter, radius, Scalar(255, 0, 0), 3, 8, 0);
            }
        }

        //-- Show the final result
        imshow(window_name, frame);
    }
    ObjectDetection

    运行示例:(已打码!)

     2.5  train_svmsgd

    #include "opencv2/core.hpp"
    #include "opencv2/video/tracking.hpp"
    #include "opencv2/imgproc.hpp"
    #include "opencv2/highgui.hpp"
    #include "opencv2/ml.hpp"
    
    using namespace cv;
    using namespace cv::ml;
    
    
    // Application state: the drawing canvas plus the training set collected from mouse clicks.
    struct Data
    {
        Mat img;              // canvas on which samples and the decision line are drawn
        Mat samples;          // training samples: one (x, y) row per clicked point
        Mat responses;        // class label of each row of `samples`

        // Creates a black 841x594 3-channel canvas and shows it in the "Train svmsgd" window.
        Data()
        {
            const int canvasWidth = 841;
            const int canvasHeight = 594;
            img = Mat::zeros(canvasHeight, canvasWidth, CV_8UC3);
            imshow("Train svmsgd", img);
        }
    };
    
    // Trains an SVMSGD model on the training set (samples, responses).
    // On success stores the decision function's weights and shift in the output
    // arguments and returns true; returns false if the model reports it is not trained.
    bool doTrain(const Mat samples, const Mat responses, Mat& weights, float& shift);
    
    // Finds points where the decision line (w·x + shift = 0) crosses the border of a
    // width x height image, so the line can be drawn between them.
    bool findPointsForLine(const Mat& weights, float shift, Point points[], int width, int height);
    
    // Finds the cross point of the line (w·x + shift = 0) with one axis-aligned border segment.
    bool findCrossPointWithBorders(const Mat& weights, float shift, const std::pair<Point, Point>& segment, Point& crossPoint);
    
    // Appends the four border segments of a width x height image (x = width, y = height, y = 0, x = 0).
    void fillSegments(std::vector<std::pair<Point, Point> >& segments, int width, int height);
    
    // Redraws all training points and the line between points[0] and points[1].
    void redraw(Data data, const Point points[2]);
    
    // Adds a point to the training set, retrains SVMSGD and draws the result on the canvas.
    void addPointRetrainAndRedraw(Data& data, int x, int y, int response);
    
    
    bool doTrain(const Mat samples, const Mat responses, Mat& weights, float& shift)
    {
        cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
    
        cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
        svmsgd->train(trainData);
    
        if (svmsgd->isTrained())
        {
            weights = svmsgd->getWeights();
            shift = svmsgd->getShift();
    
            return true;
        }
        return false;
    }
    
    void fillSegments(std::vector<std::pair<Point, Point> >& segments, int width, int height)
    {
        std::pair<Point, Point> currentSegment;
    
        currentSegment.first = Point(width, 0);
        currentSegment.second = Point(width, height);
        segments.push_back(currentSegment);
    
        currentSegment.first = Point(0, height);
        currentSegment.second = Point(width, height);
        segments.push_back(currentSegment);
    
        currentSegment.first = Point(0, 0);
        currentSegment.second = Point(width, 0);
        segments.push_back(currentSegment);
    
        currentSegment.first = Point(0, 0);
        currentSegment.second = Point(0, height);
        segments.push_back(currentSegment);
    }
    
    
    // Finds where the line w0*x + w1*y + shift = 0 crosses an axis-aligned segment.
    // weights: two-element CV_32FC1 matrix holding (w0, w1); segment: vertical or horizontal.
    // On success writes the crossing point (solved coordinate rounded down) to crossPoint
    // and returns true; returns false if the line is parallel to or misses the segment.
    // The solved coordinate is range-checked in floating point before the int conversion:
    // a tiny w0/w1 yields a huge value, and converting an out-of-range float to int is UB.
    bool findCrossPointWithBorders(const Mat& weights, float shift, const std::pair<Point, Point>& segment, Point& crossPoint)
    {
        const int xMin = std::min(segment.first.x, segment.second.x);
        const int xMax = std::max(segment.first.x, segment.second.x);
        const int yMin = std::min(segment.first.y, segment.second.y);
        const int yMax = std::max(segment.first.y, segment.second.y);

        CV_Assert(weights.type() == CV_32FC1);
        CV_Assert(xMin == xMax || yMin == yMax);

        if (xMin == xMax && weights.at<float>(1) != 0)
        {
            // Vertical segment x = xMin: solve for y.
            const int x = xMin;
            const double y = std::floor(-(weights.at<float>(0) * x + shift) / weights.at<float>(1));
            if (y >= yMin && y <= yMax)
            {
                crossPoint.x = x;
                crossPoint.y = static_cast<int>(y);
                return true;
            }
        }
        else if (yMin == yMax && weights.at<float>(0) != 0)
        {
            // Horizontal segment y = yMin: solve for x.
            const int y = yMin;
            const double x = std::floor(-(weights.at<float>(1) * y + shift) / weights.at<float>(0));
            if (x >= xMin && x <= xMax)
            {
                crossPoint.x = static_cast<int>(x);
                crossPoint.y = y;
                return true;
            }
        }
        return false;
    }
    
    bool findPointsForLine(const Mat& weights, float shift, Point points[2], int width, int height)
    {
        if (weights.empty())
        {
            return false;
        }
    
        int foundPointsCount = 0;
        std::vector<std::pair<Point, Point> > segments;
        fillSegments(segments, width, height);
    
        for (uint i = 0; i < segments.size(); i++)
        {
            if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
                foundPointsCount++;
            if (foundPointsCount >= 2)
                break;
        }
    
        return true;
    }
    
    // Clears the canvas, draws every training sample (colour chosen by the sign of its
    // response) and the line from points[0] to points[1], then shows the canvas.
    // `data` is taken by value, but cv::Mat copies are shallow, so drawing lands in the
    // caller's canvas.
    void redraw(Data data, const Point points[2])
    {
        data.img.setTo(0);
        CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1));

        const int radius = 3;
        const Scalar positiveColor(128, 128, 0);
        const Scalar negativeColor(0, 128, 128);
        for (int row = 0; row < data.samples.rows; ++row)
        {
            const Point center(static_cast<int>(data.samples.at<float>(row, 0)),
                               static_cast<int>(data.samples.at<float>(row, 1)));
            const bool positive = data.responses.at<float>(row) > 0;
            circle(data.img, center, radius, positive ? positiveColor : negativeColor, 5);
        }
        line(data.img, points[0], points[1], cv::Scalar(1, 255, 1));

        imshow("Train svmsgd", data.img);
    }
    
    // Appends the sample (x, y) with the given response to the training set, retrains
    // SVMSGD on the whole set and, if training succeeded, redraws samples and decision line.
    void addPointRetrainAndRedraw(Data& data, int x, int y, int response)
    {
        Mat sample(1, 2, CV_32FC1);
        sample.at<float>(0, 0) = static_cast<float>(x);
        sample.at<float>(0, 1) = static_cast<float>(y);
        data.samples.push_back(sample);
        data.responses.push_back(static_cast<float>(response));

        Mat weights(1, 2, CV_32FC1);
        float shift = 0;
        if (!doTrain(data.samples, data.responses, weights, shift))
            return;

        // The return value is not checked; entries findPointsForLine() does not fill keep
        // their default-constructed value of (0, 0).
        Point lineEnds[2];
        findPointsForLine(weights, shift, lineEnds, data.img.cols, data.img.rows);
        redraw(data, lineEnds);
    }
    
    
    static void onMouse(int event, int x, int y, int, void* pData)
    {
        Data& data = *(Data*)pData;
    
        switch (event)
        {
        case EVENT_LBUTTONUP:
            addPointRetrainAndRedraw(data, x, y, 1);
            break;
    
        case EVENT_RBUTTONDOWN:
            addPointRetrainAndRedraw(data, x, y, -1);
            break;
        }
    
    }
    
    // SVMSGD demo entry point: shows the canvas, registers the mouse callback and
    // runs the GUI until a key is pressed.
    int main()
    {
        Data data;   // the constructor creates and shows the "Train svmsgd" window
        setMouseCallback("Train svmsgd", onMouse, &data);
        waitKey(0);
        return 0;
    }
    train_svmsgd

    运行示例

    右键单击添加黄色点(负类样本,response = -1),左键单击添加蓝色点(正类样本,response = +1)。


    参考文献

    [1]  OpenCV4的官方实例.

    [2]  毛星云.OpenCV3编程入门[M].电子工业出版社.北京.2015.2.

  • 相关阅读:
    纪念--
    【csp模拟赛1】铁路网络 (network.cpp)
    【csp模拟赛1】不服来战 (challenge.cpp)
    【csp模拟赛1】T1 心有灵犀
    【luoguP3959 宝藏】-状压DP
    透彻网络流-wfx-最大流
    【luogu2668斗地主】模拟
    【hdu4734】F(x)-数位DP
    【8.27-模拟赛】remove
    清北学堂-济南游记
  • 原文地址:https://www.cnblogs.com/jianle23/p/13774293.html
Copyright © 2011-2022 走看看