zoukankan      html  css  js  c++  java
  • OpenCV 3.0 机器学习算法使用

    /// Random-trees classifier.
    ///
    /// Obtains the training set from prepareTrainData(data, responses, ntrain_samples),
    /// creates an RTrees model with the fixed settings below and calls train() on it.
    ///
    /// @param data            forwarded unchanged to prepareTrainData
    /// @param responses       forwarded unchanged to prepareTrainData
    /// @param ntrain_samples  forwarded unchanged to prepareTrainData
    /// @return the RTrees model after train() has been called (the result of train() is not checked)
    Ptr<StatModel> lpmlBtnClassify::buildRtreesClassifier(Mat data, Mat responses, int ntrain_samples)
    {
        Ptr<TrainData> trainSet = prepareTrainData(data, responses, ntrain_samples);

        Ptr<RTrees> forest = RTrees::create();

        // Tree-shape limits.
        forest->setMaxDepth(10);
        forest->setMinSampleCount(10);
        forest->setRegressionAccuracy(0);
        forest->setMaxCategories(15);

        // Optional features are switched off; an empty Mat is passed as the priors.
        forest->setUseSurrogates(false);
        forest->setPriors(Mat());
        forest->setCalculateVarImportance(false);

        // Termination criteria come from the setIterCondition helper (100, 0.01f).
        forest->setTermCriteria(setIterCondition(100, 0.01f));

        forest->train(trainSet);
        return forest;
    }

    /// AdaBoost classifier.
    ///
    /// The multi-class problem is unrolled into a two-class one before training:
    /// each of the first ntrain_samples rows of data is replicated class_count times,
    /// copy j gets the candidate class index j appended as an extra feature column,
    /// and its response is 1 when j equals the row's label and 0 otherwise.
    /// (class_count is defined outside this function.)
    ///
    /// @param data            feature matrix, one sample per row; rows are read through
    ///                        ptr<float>, so CV_32F data is assumed (confirm with callers)
    /// @param responses       labels, read with at<int>(i) and compared directly with the
    ///                        class indices 0 .. class_count-1
    /// @param ntrain_samples  number of leading rows of data used for training
    /// @param param0          value passed to Boost::setWeakCount
    /// @return the Boost model after train() has been called (the result of train() is not checked)
    Ptr<StatModel> lpmlBtnClassify::buildAdaboostClassifier(Mat data, Mat responses, int ntrain_samples,int param0)
    {
        const int var_count = data.cols;
        const int unrolled_rows = ntrain_samples*class_count;

        // One extra column holds the candidate class index.
        Mat new_data(unrolled_rows, var_count + 1, CV_32F);
        Mat new_responses(unrolled_rows, 1, CV_32S);

        for (int i = 0; i < ntrain_samples; i++)
        {
            const float* data_row = data.ptr<float>(i);
            const int label = responses.at<int>(i);

            for (int j = 0; j < class_count; j++)
            {
                const int out_row = i*class_count + j;
                float* new_data_row = new_data.ptr<float>(out_row);

                memcpy(new_data_row, data_row, var_count*sizeof(data_row[0]));
                new_data_row[var_count] = static_cast<float>(j);
                new_responses.at<int>(out_row) = (label == j);
            }
        }

        // var_count ordered features, then the appended class-index column and the
        // response, both of which are marked categorical.
        Mat var_type(1, var_count + 2, CV_8U);
        var_type.setTo(Scalar::all(VAR_ORDERED));
        var_type.at<uchar>(var_count) = var_type.at<uchar>(var_count + 1) = VAR_CATEGORICAL;

        Ptr<TrainData> tdata = TrainData::create(new_data, ROW_SAMPLE, new_responses,
                                                 noArray(), noArray(), noArray(), var_type);

        Ptr<Boost> model = Boost::create();
        model->setBoostType(Boost::GENTLE);
        model->setWeakCount(param0);
        model->setWeightTrimRate(0.95);
        model->setMaxDepth(5);
        model->setUseSurrogates(false);
        model->train(tdata);

        return model;
    }

    /// Multi-layer perceptron (ANN) classifier.
    ///
    /// Trains an ANN_MLP on the first ntrain_samples rows of data. The integer labels in
    /// responses are expanded into one row per sample with class_count float columns,
    /// holding 1 in the label's column and 0 elsewhere. (class_count is defined outside
    /// this function.)
    ///
    /// @param data            feature matrix, one sample per row; data.cols is the input layer size
    /// @param responses       labels, read with at<int>(i) and used as a column index
    /// @param ntrain_samples  number of leading rows used for training
    /// @return the ANN_MLP model after train() has been called (the result of train() is not checked)
    Ptr<StatModel> lpmlBtnClassify::buildMlpClassifier(Mat data, Mat responses, int ntrain_samples)
    {
        // (disabled) read_num_class_data(data_filename, 16, &data, &responses);

        // Training-method switch: change the condition to select RPROP instead of BACKPROP.
    #if 1
        int train_method = ANN_MLP::BACKPROP;
        double train_method_param = 0.001;
        int iteration_limit = 300;
    #else
        int train_method = ANN_MLP::RPROP;
        double train_method_param = 0.1;
        int iteration_limit = 1000;
    #endif

        // Step 1: unroll the responses into a 0/1 target matrix.
        Mat samples = data.rowRange(0, ntrain_samples);
        Mat targets = Mat::zeros(ntrain_samples, class_count, CV_32F);
        for (int row = 0; row < ntrain_samples; ++row)
        {
            targets.at<float>(row, responses.at<int>(row)) = 1.f;
        }

        // Step 2: describe the network: input layer, two hidden layers of 100, output layer.
        int topology[] = { data.cols, 100, 100, class_count };
        int layer_count = (int)(sizeof(topology) / sizeof(topology[0]));
        Mat layer_sizes(1, layer_count, CV_32S, topology);

        // Step 3: configure and train.
        Ptr<TrainData> trainSet = TrainData::create(samples, ROW_SAMPLE, targets);
        Ptr<ANN_MLP> network = ANN_MLP::create();
        network->setLayerSizes(layer_sizes);
        network->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
        network->setTermCriteria(setIterCondition(iteration_limit, 0));
        network->setTrainMethod(train_method, train_method_param);
        network->train(trainSet);

        return network;
    }


    /// Normal Bayes classifier.
    ///
    /// Obtains the training set from prepareTrainData(data, responses, ntrain_samples),
    /// creates a NormalBayesClassifier and calls train() on it.
    ///
    /// @param data            forwarded unchanged to prepareTrainData
    /// @param responses       forwarded unchanged to prepareTrainData
    /// @param ntrain_samples  forwarded unchanged to prepareTrainData
    /// @return the model after train() has been called (the result of train() is not checked)
    Ptr<StatModel> lpmlBtnClassify::buildNbayesClassifier(Mat data, Mat responses, int ntrain_samples)
    {
        Ptr<TrainData> trainSet = prepareTrainData(data, responses, ntrain_samples);

        Ptr<NormalBayesClassifier> bayes = NormalBayesClassifier::create();
        bayes->train(trainSet);
        return bayes;
    }

    /// K-nearest-neighbours classifier.
    ///
    /// Obtains the training set from prepareTrainData(data, responses, ntrain_samples),
    /// creates a KNearest model in classifier mode with default K set to K, and calls
    /// train() on it.
    ///
    /// @param data            forwarded unchanged to prepareTrainData
    /// @param responses       forwarded unchanged to prepareTrainData
    /// @param ntrain_samples  forwarded unchanged to prepareTrainData
    /// @param K               value passed to KNearest::setDefaultK
    /// @return the model after train() has been called (the result of train() is not checked)
    Ptr<StatModel> lpmlBtnClassify::buildKnnClassifier(Mat data, Mat responses, int ntrain_samples, int K)
    {
        Ptr<TrainData> trainSet = prepareTrainData(data, responses, ntrain_samples);

        Ptr<KNearest> knn = KNearest::create();
        knn->setDefaultK(K);
        knn->setIsClassifier(true);
        knn->train(trainSet);

        return knn;
    }

    /// SVM classifier.
    ///
    /// Obtains the training set from prepareTrainData(data, responses, ntrain_samples),
    /// creates a C_SVC support vector machine with an RBF kernel and C = 1, and calls
    /// train() on it.
    ///
    /// @param data            forwarded unchanged to prepareTrainData
    /// @param responses       forwarded unchanged to prepareTrainData
    /// @param ntrain_samples  forwarded unchanged to prepareTrainData
    /// @return the model after train() has been called (the result of train() is not checked)
    Ptr<StatModel> lpmlBtnClassify::buildSvmClassifier(Mat data, Mat responses, int ntrain_samples)
    {
        Ptr<TrainData> trainSet = prepareTrainData(data, responses, ntrain_samples);

        Ptr<SVM> svm = SVM::create();
        svm->setType(SVM::C_SVC);
        svm->setKernel(SVM::RBF);
        svm->setC(1);
        svm->train(trainSet);

        return svm;
    }

  • 相关阅读:
    mysql定时器,定时查询数据库,把查询结果插入到一张表中 阿星小栈
    如何写mysql的定时任务 阿星小栈
    利用mysql游标循环结果集 阿星小栈
    页面可见性 Page Visibility
    css之z-index
    css之页面三列布局之左右两边宽度固定,中间自适应
    css之页面两列布局
    jquery源码学习之extend
    jquery源码学习之queue方法
    HTTP状态码详解
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/9866502.html
Copyright © 2011-2022 走看看