zoukankan      html  css  js  c++  java
  • OpenCV实现KNN算法

    原文 OpenCV实现KNN算法

    K Nearest Neighbors

    这个算法首先贮藏所有的训练样本,然后通过分析(包括选举,计算加权和等方式)一个新样本周围K个最近邻以给出该样本的相应值。这种方法有时候被称作“基于样本的学习”,即为了预测,我们对于给定的输入搜索最近的已知其相应的特征向量。

    class CvKNearest : public CvStatModel //继承自ML库中的统计模型基类
    {
    public:
      
        CvKNearest();//无参构造函数
        virtual ~CvKNearest();  //虚函数定义
      
        CvKNearest( const CvMat* _train_data, const CvMat* _responses,
                    const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );//有参构造函数
      
        virtual bool train( const CvMat* _train_data, const CvMat* _responses,
                            const CvMat* _sample_idx=0, bool is_regression=false,
                            int _max_k=32, bool _update_base=false );
      
        virtual float find_nearest( const CvMat* _samples, int k, CvMat* results,
            const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
      
        virtual void clear();
        int get_max_k() const;
        int get_var_count() const;
        int get_sample_count() const;
        bool is_regression() const;
      
    protected:
        ...
    };

    CvKNearest::train

    训练KNN模型

    bool CvKNearest::train( const CvMat* _train_data, const CvMat* _responses,
                            const CvMat* _sample_idx=0, bool is_regression=false,
                            int _max_k=32, bool _update_base=false );

    这个类的方法训练K近邻模型。它遵循一个一般训练方法约定的限制:只支持CV_ROW_SAMPLE数据格式,输入向量必须都是有序的,而输出可以 是 无序的(当is_regression=false),可以是有序的(is_regression=true)。并且变量子集和省略度量是不被支持的。

    参数_max_k 指定了最大邻居的个数,它将被传给方法find_nearest。 参数 _update_base 指定模型是由原来的数据训练(_update_base=false),还是被新训练数据更新后再训练(_update_base=true)。在后一种情况下_max_k 不能大于原值, 否则它会被忽略.

    CvKNearest::find_nearest

    寻找输入向量的最近邻

    float CvKNearest::find_nearest( const CvMat* _samples, int k, CvMat* results=0,
            const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;

    对每个输入向量(表示为matrix_sample的每一行),该方法找到k(k≤get_max_k() )个最近邻。在回归中,预测结果将是指定向量的近邻的响应的均值。在分类中,类别将由投票决定。

    对传统分类和回归预测来说,该方法可以有选择的返回近邻向量本身的指针(neighbors, array of k*_samples->rows pointers),它们相对应的输出值(neighbor_responses, a vector of k*_samples->rows elements) ,和输入向量与近邻之间的距离(dist, also a vector of k*_samples->rows elements)。

    对每个输入向量来说,近邻将按照它们到该向量的距离排序。

    对单个输入向量,所有的输出矩阵是可选的,而且预测值将由该方法返回。

    例程:使用kNN进行2维样本集的分类,样本集的分布为混合高斯分布

    #include "ml.h"
    #include "highgui.h"
      
    int main( int argc, char** argv )
    {
        const int K = 10;
        int i, j, k, accuracy;
        float response;
        int train_sample_count = 100;
        CvRNG rng_state = cvRNG(-1);
        CvMat* trainData = cvCreateMat( train_sample_count, 2, CV_32FC1 );
        CvMat* trainClasses = cvCreateMat( train_sample_count, 1, CV_32FC1 );
        IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
        float _sample[2];
        CvMat sample = cvMat( 1, 2, CV_32FC1, _sample );
        cvZero( img );
      
        CvMat trainData1, trainData2, trainClasses1, trainClasses2;
      
        // form the training samples
        cvGetRows( trainData, &trainData1, 0, train_sample_count/2 );
        cvRandArr( &rng_state, &trainData1, CV_RAND_NORMAL, cvScalar(200,200), cvScalar(50,50) );
      
        cvGetRows( trainData, &trainData2, train_sample_count/2, train_sample_count );
        cvRandArr( &rng_state, &trainData2, CV_RAND_NORMAL, cvScalar(300,300), cvScalar(50,50) );
      
        cvGetRows( trainClasses, &trainClasses1, 0, train_sample_count/2 );
        cvSet( &trainClasses1, cvScalar(1) );
      
        cvGetRows( trainClasses, &trainClasses2, train_sample_count/2, train_sample_count );
        cvSet( &trainClasses2, cvScalar(2) );
      
        // learn classifier
        CvKNearest knn( trainData, trainClasses, 0, false, K );
        CvMat* nearests = cvCreateMat( 1, K, CV_32FC1);
      
        for( i = 0; i < img->height; i++ )
        {
            for( j = 0; j < img->width; j++ )
            {
                sample.data.fl[0] = (float)j;
                sample.data.fl[1] = (float)i;
      
                // estimates the response and get the neighbors' labels
                response = knn.find_nearest(&sample,K,0,0,nearests,0);
      
                // compute the number of neighbors representing the majority
                for( k = 0, accuracy = 0; k < K; k++ )
                {
                    if( nearests->data.fl[k] == response)
                        accuracy++;
                }
                // highlight the pixel depending on the accuracy (or confidence)
                cvSet2D( img, i, j, response == 1 ?
                    (accuracy > 5 ? CV_RGB(180,0,0) : CV_RGB(180,120,0)) :
                    (accuracy > 5 ? CV_RGB(0,180,0) : CV_RGB(120,120,0)) );
            }
        }
      
        // display the original training samples
        for( i = 0; i < train_sample_count/2; i++ )
        {
            CvPoint pt;
            pt.x = cvRound(trainData1.data.fl[i*2]);
            pt.y = cvRound(trainData1.data.fl[i*2+1]);
            cvCircle( img, pt, 2, CV_RGB(255,0,0), CV_FILLED );
            pt.x = cvRound(trainData2.data.fl[i*2]);
            pt.y = cvRound(trainData2.data.fl[i*2+1]);
            cvCircle( img, pt, 2, CV_RGB(0,255,0), CV_FILLED );
        }
      
        cvNamedWindow( "classifier result", 1 );
        cvShowImage( "classifier result", img );
        cvWaitKey(0);
      
        cvReleaseMat( &trainClasses );
        cvReleaseMat( &trainData );
        return 0;
    }

    结果:

  • 相关阅读:
    Lock wait timeout exceeded; try restarting transaction linux设置mysql innodb_lock_wait_timeout
    用NaviCat创建存储过程批量添加测试数据
    mysql存储过程语法及实例
    mysql中迅速插入百万条测试数据的方法
    mysql学习之通过文件创建数据库以及添加数据
    有用的网站集合
    VMware Workstation虚拟磁盘文件备份或移植
    CoreData修改了数据模型报错 The model used to open the store is incompatible with the one used to create the store
    iOS中自定义UITableViewCell的用法
    golang make()的第三个参数
  • 原文地址:https://www.cnblogs.com/arxive/p/6215154.html
Copyright © 2011-2022 走看看