zoukankan      html  css  js  c++  java
  • EM算法学习笔记_2(opencv自带EM sample学习)

      实验说明:

      在上一讲EM算法学习笔记_1(EM算法的简单理解) 中已经用通俗的语言简单的介绍了下EM算法,在这一节中就采用opencv自带的一个EM sample来学习下opencvEM 算法类的使用,顺便也体验下EM 算法的实际应用。

      环境:Ubuntu12.04+Qt4.8.2+QtCreator2.5+opencv2.4.2

      在这里需要使用2个与EM算法有关的类,即CvEMCvEMParams,这2个类在opencv2.4.2已经放入legacy文件夹中了,说明不久就会被淘汰掉,因为在未来的opencv版本中,将采用Algorithm这个公共类来统一接口。不过CvEMCvEMParams的使用与其类似,且可以熟悉EM算法的使用流程。

      需要注意的是这2个类虽然是与EM算法有关,可是只能解决GMM问题,比较局限。也许这是将其放在legacy中的原因吧。

     

      实验流程:

      首先产生需要聚类的样本数据,我这里采用的是9个混合的二维高斯分布,所以需要被聚类成9类,这些GMM排成3*3的格式,每一格25个点,共225个训练样本。在软件中显示出样本点的分布。

      用类EMCvEMParams初始化emem_params对象。

      设置EM参数类em_params的各个参数,这里的均值、权值、方差的初始化采用的是kmeans聚类得到的,em_params的参数中需要特别指定的是所聚类类别N(这里等于9.

      用这255个数据进行训练EM模型,采用的是CvEM类方法train()函数。

      把窗口大小500*500内的每个点用训练出来的EM模型进行预测,将预测结果用不同的颜色在软件中画出来。

      把训练过程中样本的类别标签(程序中保存在label中)在图像中显示出来。

     

      实验结果:

      软件界面图:

      

      按下Gnenrate Data按钮后显示如下:

      

      按下EM Cluster按钮后显示如下:

      


      实验代码

    mainwindow.h:

    #ifndef MAINWINDOW_H
    #define MAINWINDOW_H
    
    #include <QMainWindow>
    //#include <vector>
    #include <opencv2/core/core.hpp>
    #include <opencv2/ml/ml.hpp>
    #include <opencv2/highgui/highgui.hpp>
    #include <opencv2/legacy/legacy.hpp>
    
    using namespace cv;
    using namespace std;
    //using std::vector;
    
    namespace Ui {
    class MainWindow;
    }
    
    class MainWindow : public QMainWindow
    {
        Q_OBJECT
        
    public:
        explicit MainWindow(QWidget *parent = 0);
        ~MainWindow();
    
        vector<Scalar> colors;
        
    private slots:
    
        void on_closeButton_clicked();
    
        void on_generateButton_clicked();
    
        void on_clusterButton_clicked();
    
    private:
        Ui::MainWindow *ui;
    
        int nsamples;
        int N, N1;
        Mat img, img1;
        Mat samples, sample_predict;
        Mat labels;
        CvEM em;
        CvEMParams em_params;
    };
    
    #endif // MAINWINDOW_H

    mainwindow.cpp:

    #include "mainwindow.h"
    #include "ui_mainwindow.h"
    #include <QImage>
    
    MainWindow::MainWindow(QWidget *parent) :
        QMainWindow(parent),
        ui(new Ui::MainWindow)
    {
        ui->setupUi(this);
        N = 9;
        N1 = (int)sqrt(double(N));
        nsamples = 225;
        img = Mat( Size(500, 500), CV_8UC3 );
    
        colors.resize(N);
        colors.at(0) = Scalar(0, 255, 255);
        colors.at(1) = Scalar(255, 0, 255);
        colors.at(2) = Scalar(255, 255, 0);
        colors.at(3) = Scalar(255, 0, 0);
        colors.at(4) = Scalar(0, 255, 0);
        colors.at(5) = Scalar(0, 0, 255);
        colors.at(6) = Scalar(255, 100, 100);
        colors.at(7) = Scalar(100, 255, 100);
        colors.at(8) = Scalar(100, 100, 255);
    
    }
    
    MainWindow::~MainWindow()
    {
        delete ui;
    }
    
    
    void MainWindow::on_closeButton_clicked()
    {
        close();
    }
    
    void MainWindow::on_generateButton_clicked()
    {
        samples = Mat( nsamples, 2, CV_32FC1);//用来存储产生的二维随机点
        samples = samples.reshape( 2, 0 );//转换成2通道的矩阵,reshape函数只适应而2维图像
    
        //初始化样本
        for( int i = 0; i < N; i++ )
            {
                Mat sub_samples = samples.rowRange( i*nsamples/N, (i+1)*nsamples/N );
                Scalar mean( (i%N1+1)*img.rows/(N1+1), (i/N1+1)*img.rows/(N1+1));
                Scalar var( 30, 30 );
                randn( sub_samples, mean, var );
            }
        samples = samples.reshape( 1, 0 );
    
        //显示样本数据
        for( int j = 0; j < nsamples; j++ )
        {
            Point gene_sample;
            gene_sample.x = cvRound(samples.at<float>(j, 0));
            gene_sample.y = cvRound(samples.at<float>(j, 1));
            circle( img, gene_sample, 1, Scalar(0, 255, 250), 1, 8 );
        }
        cvtColor( img, img, CV_BGR2RGB );
    
        /*Qt中处理图像有4个类,分别为QImage,QPixmap,QBitmap,QPicture.其中QPixmap专门负责在屏幕上显示图片
        的,QImage专门负责和I/O方面的,QBitmap是从QPixmap中继承来的,只负责一个通道的图像处理,QPicture是
        专门用来负责画图的*/
        QImage qimg = QImage( img.data, img.cols, img.rows, QImage::Format_RGB888 );
        //setPixmap为QLabel发出的公共信号,fromImage函数为将图片转换程QPixmap的格式
        ui->imgLabel->setPixmap( QPixmap::fromImage( qimg ) );
    }
    
    void MainWindow::on_clusterButton_clicked()
    {
        //给EM算法参赛赋值,均值,方差和权值采用kmeans初步聚类得到
        em_params.means = NULL;
        em_params.covs = NULL;
        em_params.weights = NULL;
        em_params.nclusters = N;
        em_params.start_step = CvEM::START_AUTO_STEP;
        em_params.cov_mat_type = CvEM::COV_MAT_SPHERICAL;
        //达到最大迭代次数或者迭代误差小到一定值,应该有系统默认的值
        em_params.term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
    
        cvtColor( img, img, CV_RGB2BGR );
    
        //EM算法训练过程
        em.train( samples, Mat(), em_params, &labels );
    
        //画出背景图
        sample_predict = Mat( 1, 2, CV_32FC1 );
        for( int i = 0; i < img.rows; i++ )
            for( int j = 0; j < img.cols; j++ )
                {
                    sample_predict.at<float>(0) = (float)i;
                    sample_predict.at<float>(1) = (float)j;
                    int value = cvRound(em.predict( sample_predict ));//返回的value为预测类标签
                    circle( img, Point(i, j), 1, 0.1*colors.at(value), 1, 8 );
                }
    
        //画出样本点的聚类情况
        for( int n = 0; n < nsamples; n++ )
            circle( img, Point(cvRound(samples.at<float>(n, 0)), cvRound(samples.at<float>(n, 1))),
                    1, colors.at( labels.at<int>(n)), 1, 8 );//因为此时labels保存的是类标签(1~N),为整型
    
        //显示图像
        cvtColor( img, img, CV_BGR2RGB );
        QImage qimg = QImage( img.data, img.cols, img.rows, QImage::Format_RGB888 );
        ui->imgLabel->setPixmap( QPixmap::fromImage(qimg) );
    
    
    }

    main.cpp:

    #include <QApplication>
    #include "mainwindow.h"
    
    int main(int argc, char *argv[])
    {
        QApplication a(argc, argv);
        MainWindow w;
        w.show();
        
        return a.exec();
    }

      实验总结:

      要学会数据点产生的类似方法,特别是reshape函数的使用方法。

      要学会用STL的vector,这个容器要比数组方便很多。

      要多学点C++的编程思想。

       附录:工程code下载地址


    作者:tornadomeet 出处:http://www.cnblogs.com/tornadomeet 欢迎转载或分享,但请务必声明文章出处。 (新浪微博:tornadomeet,欢迎交流!)
  • 相关阅读:
    linux tcpdump抓包,wireshark实时解析
    TLS协议分析
    sqlite sql语句关键字GROUP BY的理解
    使用 openssl 生成证书
    linux C单元测试工具CUnit的编译安装及使用
    http短连接大量time wait解决方案
    gdb调试行号错位
    libevent 多线程
    C语言单元测试
    客户端端口分配
  • 原文地址:https://www.cnblogs.com/tornadomeet/p/2592953.html
Copyright © 2011-2022 走看看