zoukankan      html  css  js  c++  java
  • 在opencv3中的机器学习算法

    在opencv3.0中,提供了一个ml.cpp的文件,这里面全是机器学习的算法,共提供了这么几种:

    1、正态贝叶斯:normal Bayessian classifier    我已在另外一篇博文中介绍过:在opencv3中实现机器学习之:利用正态贝叶斯分类

    2、K最近邻:k nearest neighbors classifier

    3、支持向量机:support vectors machine    请参考我的另外一篇博客:在opencv3中实现机器学习之:利用svm(支持向量机)分类

    4、决策树: decision tree

    5、ADA Boost:adaboost

    6、梯度提升决策树:gradient boosted trees

    7、随机森林:random forest

    8、人工神经网络:artificial neural networks

    9、EM算法:expectation-maximization

    这些算法在任何一本机器学习书本上都可以介绍过,他们大致的分类过程都很相似,主要分为三个环节:

    一、收集样本数据sampleData

    二、训练分类器mode

    三、对测试数据testData进行预测

    不同的地方就是在opencv中的参数设定,假设训练数据为trainingDataMat,且已经标注好labelsMat。待测数据为testMat.

    1、正态贝叶斯

     // 创建贝叶斯分类器
      Ptr<NormalBayesClassifier> model=NormalBayesClassifier::create();
        
        // 设置训练数据
      Ptr<TrainData> tData =TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
    
        //训练分类器
        model->train(tData);
    //预测数据
     float response = model->predict(testMat); 

    2、K最近邻

     Ptr<KNearest> knn = KNearest::create();  //创建knn分类器
        knn->setDefaultK(K);    //设定k值
        knn->setIsClassifier(true);
        // 设置训练数据
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        knn->train(tData);
        float response = knn->predict(testMat);

    3、支持向量机

    Ptr<SVM> svm = SVM::create();    //创建一个分类器
        svm->setType(SVM::C_SVC);    //设置svm类型
        svm->setKernel(SVM::POLY); //设置核函数;
        svm->setDegree(0.5);
        svm->setGamma(1);
        svm->setCoef0(1);
        svm->setNu(0.5);
        svm->setP(0);
        svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01));
        svm->setC(C);
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        svm->train(tData);
        float response = svm->predict(testMat);

    4、决策树: decision tree

    Ptr<DTrees> dtree = DTrees::create();  //创建分类器
        dtree->setMaxDepth(8);   //设置最大深度
        dtree->setMinSampleCount(2);  
        dtree->setUseSurrogates(false);
        dtree->setCVFolds(0); //交叉验证
        dtree->setUse1SERule(false);
        dtree->setTruncatePrunedTree(false);
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        dtree->train(tData);
        float response = dtree->predict(testMat);

    5、ADA Boost:adaboost

     Ptr<Boost> boost = Boost::create();
        boost->setBoostType(Boost::DISCRETE);
        boost->setWeakCount(100);
        boost->setWeightTrimRate(0.95);
        boost->setMaxDepth(2);
        boost->setUseSurrogates(false);
        boost->setPriors(Mat());
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        boost->train(tData);
        float response = boost->predict(testMat);

    6、梯度提升决策树:gradient boosted trees

    此算法在opencv3.0中被注释掉了,原因未知,因此此处提供一个老版本的算法。

    GBTrees::Params params( GBTrees::DEVIANCE_LOSS, // loss_function_type
                             100, // weak_count
                             0.1f, // shrinkage
                             1.0f, // subsample_portion
                             2, // max_depth
                             false // use_surrogates )
                             );
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        Ptr<GBTrees> gbtrees = StatModel::train<GBTrees>(tData, params);
        float response = gbtrees->predict(testMat);

    7、随机森林:random forest

       Ptr<RTrees> rtrees = RTrees::create();
        rtrees->setMaxDepth(4);
        rtrees->setMinSampleCount(2);
        rtrees->setRegressionAccuracy(0.f);
        rtrees->setUseSurrogates(false);
        rtrees->setMaxCategories(16);
        rtrees->setPriors(Mat());
        rtrees->setCalculateVarImportance(false);
        rtrees->setActiveVarCount(1);
        rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 5, 0));
       Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
       rtrees->train(tData);
       float response = rtrees->predict(testMat);

    8、人工神经网络:artificial neural networks

     Ptr<ANN_MLP> ann = ANN_MLP::create();
        ann->setLayerSizes(layer_sizes);
        ann->setActivationFunction(ANN_MLP::SIGMOID_SYM, 1, 1);
        ann->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON));
        ann->setTrainMethod(ANN_MLP::BACKPROP, 0.001);
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        ann->train(tData);
        float response = ann->predict(testMat);

    9、EM算法:expectation-maximization

    EM算法与前面的稍微有点不同,它需要创建很多个model,将trainingDataMat分成很多个modelSamples,每个modelSamples训练出一个model

    训练核心代码为:

     int nmodels = (int)labelsMat.size();
        vector<Ptr<EM> > em_models(nmodels);
        Mat modelSamples;
    
        for( i = 0; i < nmodels; i++ )
        {
            const int componentCount = 3;
    
            modelSamples.release();
            for (j = 0; j < labelsMat.rows; j++)
            {
                if (labelsMat.at<int>(j,0)== i)
                    modelSamples.push_back(trainingDataMat.row(j));
            }
    
            // learn models
            if( !modelSamples.empty() )
            {
                Ptr<EM> em = EM::create();
                em->setClustersNumber(componentCount);
                em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL);
                em->trainEM(modelSamples, noArray(), noArray(), noArray());
                em_models[i] = em;
            }
        }

    预测:

     Mat logLikelihoods(1, nmodels, CV_64FC1, Scalar(-DBL_MAX));
     for( i = 0; i < nmodels; i++ )
                {
                    if( !em_models[i].empty() )
                        logLikelihoods.at<double>(i) = em_models[i]->predict2(testMat, noArray())[0];
                }

    这么多的机器学习算法,在实际用途中照我的理解其实只需要掌握svm算法就可以了。

    ANN算法在opencv中也叫多层感知机,因此在训练的时候,需要分多层。

    EM算法需要为每一类创建一个model。

    其中一些算法的具体代码练习:在opencv3中的机器学习算法练习:对OCR进行分类

  • 相关阅读:
    gcc/g++动态链接库和静态库的链接顺序
    linux文件描述符--转载
    mysql之test表
    linux之eventfd()
    spring data jpa实现多条件查询(分页和不分页)
    List的remove方法里的坑
    启动Spring boot报错:nested exception is java.sql.SQLException: Field 'id' doesn't have a default value
    CentOS下yum安装jdk
    CentOS下yum安装mysql
    Redis有序Set、无序Set的使用经历
  • 原文地址:https://www.cnblogs.com/denny402/p/5032232.html
Copyright © 2011-2022 走看看