zoukankan      html  css  js  c++  java
  • 在opencv3中的机器学习算法

    在opencv3.0中,提供了一个ml.cpp的文件,这里面全是机器学习的算法,共提供了这么几种:

    1、正态贝叶斯:normal Bayessian classifier    我已在另外一篇博文中介绍过:在opencv3中实现机器学习之:利用正态贝叶斯分类

    2、K最近邻:k nearest neighbors classifier

    3、支持向量机:support vectors machine    请参考我的另外一篇博客:在opencv3中实现机器学习之:利用svm(支持向量机)分类

    4、决策树: decision tree

    5、ADA Boost:adaboost

    6、梯度提升决策树:gradient boosted trees

    7、随机森林:random forest

    8、人工神经网络:artificial neural networks

    9、EM算法:expectation-maximization

    这些算法在任何一本机器学习书本上都可以介绍过,他们大致的分类过程都很相似,主要分为三个环节:

    一、收集样本数据sampleData

    二、训练分类器mode

    三、对测试数据testData进行预测

    不同的地方就是在opencv中的参数设定,假设训练数据为trainingDataMat,且已经标注好labelsMat。待测数据为testMat.

    1、正态贝叶斯

     // 创建贝叶斯分类器
      Ptr<NormalBayesClassifier> model=NormalBayesClassifier::create();
        
        // 设置训练数据
      Ptr<TrainData> tData =TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
    
        //训练分类器
        model->train(tData);
    //预测数据
     float response = model->predict(testMat); 

    2、K最近邻

     Ptr<KNearest> knn = KNearest::create();  //创建knn分类器
        knn->setDefaultK(K);    //设定k值
        knn->setIsClassifier(true);
        // 设置训练数据
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        knn->train(tData);
        float response = knn->predict(testMat);

    3、支持向量机

    Ptr<SVM> svm = SVM::create();    //创建一个分类器
        svm->setType(SVM::C_SVC);    //设置svm类型
        svm->setKernel(SVM::POLY); //设置核函数;
        svm->setDegree(0.5);
        svm->setGamma(1);
        svm->setCoef0(1);
        svm->setNu(0.5);
        svm->setP(0);
        svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01));
        svm->setC(C);
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        svm->train(tData);
        float response = svm->predict(testMat);

    4、决策树: decision tree

    Ptr<DTrees> dtree = DTrees::create();  //创建分类器
        dtree->setMaxDepth(8);   //设置最大深度
        dtree->setMinSampleCount(2);  
        dtree->setUseSurrogates(false);
        dtree->setCVFolds(0); //交叉验证
        dtree->setUse1SERule(false);
        dtree->setTruncatePrunedTree(false);
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        dtree->train(tData);
        float response = dtree->predict(testMat);

    5、ADA Boost:adaboost

     Ptr<Boost> boost = Boost::create();
        boost->setBoostType(Boost::DISCRETE);
        boost->setWeakCount(100);
        boost->setWeightTrimRate(0.95);
        boost->setMaxDepth(2);
        boost->setUseSurrogates(false);
        boost->setPriors(Mat());
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        boost->train(tData);
        float response = boost->predict(testMat);

    6、梯度提升决策树:gradient boosted trees

    此算法在opencv3.0中被注释掉了,原因未知,因此此处提供一个老版本的算法。

    GBTrees::Params params( GBTrees::DEVIANCE_LOSS, // loss_function_type
                             100, // weak_count
                             0.1f, // shrinkage
                             1.0f, // subsample_portion
                             2, // max_depth
                             false // use_surrogates )
                             );
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        Ptr<GBTrees> gbtrees = StatModel::train<GBTrees>(tData, params);
        float response = gbtrees->predict(testMat);

    7、随机森林:random forest

       Ptr<RTrees> rtrees = RTrees::create();
        rtrees->setMaxDepth(4);
        rtrees->setMinSampleCount(2);
        rtrees->setRegressionAccuracy(0.f);
        rtrees->setUseSurrogates(false);
        rtrees->setMaxCategories(16);
        rtrees->setPriors(Mat());
        rtrees->setCalculateVarImportance(false);
        rtrees->setActiveVarCount(1);
        rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 5, 0));
       Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
       rtrees->train(tData);
       float response = rtrees->predict(testMat);

    8、人工神经网络:artificial neural networks

     Ptr<ANN_MLP> ann = ANN_MLP::create();
        ann->setLayerSizes(layer_sizes);
        ann->setActivationFunction(ANN_MLP::SIGMOID_SYM, 1, 1);
        ann->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON));
        ann->setTrainMethod(ANN_MLP::BACKPROP, 0.001);
        Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
        ann->train(tData);
        float response = ann->predict(testMat);

    9、EM算法:expectation-maximization

    EM算法与前面的稍微有点不同,它需要创建很多个model,将trainingDataMat分成很多个modelSamples,每个modelSamples训练出一个model

    训练核心代码为:

     int nmodels = (int)labelsMat.size();
        vector<Ptr<EM> > em_models(nmodels);
        Mat modelSamples;
    
        for( i = 0; i < nmodels; i++ )
        {
            const int componentCount = 3;
    
            modelSamples.release();
            for (j = 0; j < labelsMat.rows; j++)
            {
                if (labelsMat.at<int>(j,0)== i)
                    modelSamples.push_back(trainingDataMat.row(j));
            }
    
            // learn models
            if( !modelSamples.empty() )
            {
                Ptr<EM> em = EM::create();
                em->setClustersNumber(componentCount);
                em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL);
                em->trainEM(modelSamples, noArray(), noArray(), noArray());
                em_models[i] = em;
            }
        }

    预测:

     Mat logLikelihoods(1, nmodels, CV_64FC1, Scalar(-DBL_MAX));
     for( i = 0; i < nmodels; i++ )
                {
                    if( !em_models[i].empty() )
                        logLikelihoods.at<double>(i) = em_models[i]->predict2(testMat, noArray())[0];
                }

    这么多的机器学习算法,在实际用途中照我的理解其实只需要掌握svm算法就可以了。

    ANN算法在opencv中也叫多层感知机,因此在训练的时候,需要分多层。

    EM算法需要为每一类创建一个model。

    其中一些算法的具体代码练习:在opencv3中的机器学习算法练习:对OCR进行分类

  • 相关阅读:
    LOJ 6089 小Y的背包计数问题 —— 前缀和优化DP
    洛谷 P1969 积木大赛 —— 水题
    洛谷 P1965 转圈游戏 —— 快速幂
    洛谷 P1970 花匠 —— DP
    洛谷 P1966 火柴排队 —— 思路
    51Nod 1450 闯关游戏 —— 期望DP
    洛谷 P2312 & bzoj 3751 解方程 —— 取模
    洛谷 P1351 联合权值 —— 树形DP
    NOIP2007 树网的核
    平面最近点对(加强版)
  • 原文地址:https://www.cnblogs.com/denny402/p/5032232.html
Copyright © 2011-2022 走看看