zoukankan      html  css  js  c++  java
  • TrainData类型,拆分数据集setTrainTestSplitRatio(),计算准确率calcError()(OpenCV案例源码letter_recog.cpp解读3)

    机器学习中,需要总样本集,拆分成训练集、测试集,计算训练、测试、整体的准确率。

    OpenCV在ml.hpp中为我们准备了特有格式TrainData,它会把标签、特征集融合到其中,方便操作。

    针对TrainData类型,提供了非常完美的函数,具体介绍如下:

    1、拆分TrainData类型总样本集,注意默认是会打乱行顺序的。

    setTrainTestSplitRatio(double ratio, bool shuffle=true); //比例方式,前ratio(0~1)行是训练集,推荐使用此函数

    setTrainTestSplit(int count, bool shuffle=true); //具体指定方式,前count行是训练集
    第二个参数默认为true,即打乱行后再拆分。

    2、计算准确率,此函数计算的是错误率,准确率用100减下即可

    float calcError( const Ptr<TrainData>& data, bool test, OutputArray resp )
      data——融合了标签、特征的总样本集,TrainData类型
      test——计算训练集准确率设为false,测试集准确率设为true,总的准确率设为false同时data不可以使用拆分函数setTrainTestSplitRatio或setTrainTestSplit。
      resp——预测的标签结果,内部调用了predict函数

    3、获取训练集、测试集

    Mat getTrainSamples()
    Mat getTestSamples()

    【TrainData类型的解释】

    下图是letter-recognition.data文件的内容,第1列是标签,后16列是特征,需要分别读入到Mat格式的responses和samples中。之后用create函数创建TrainData类型对象(融合了标签、特征)。

    此解释针对的是分类问题,而不是回归拟合。samples特征集中每1行为1个样本,每1列为1个特征。
    Ptr<TrainData> create(InputArray samples, int layout, InputArray responses, InputArray varIdx=noArray(), InputArray sampleIdx=noArray(),InputArray sampleWeights=noArray(), InputArray varType=noArray());
      samples——特征集(特征变量组成的集合),必须是32FC1(32位浮点单通道)类型
      layout——样本布局,ROW_SAMPLE = 0,COL_SAMPLE = 1,此处只用前者
      responses——标签集,必须是CV_32S类型(即int型)。一维向量,与每一行样本对应。
      varIdx——参与训练的特征(默认都参与),元素为0、非0的一维向量(行列都可),大小为特征数(列数)。1对应的样本参与训练(非0有效),格式用CV_8U即可
      sampleIdx——参与训练的样本(默认都参与),同上,大小为总样本数(行数)
      sampleWeights——样本权重,忽略不用
      varType——特征变量的类型,忽略不用

    总样本集(上图)中responses在哪一列都可以,只要保证samples(行为样本,列为特征)、responses(行、列都可以)即可。

    将总样本集拆分为训练集和测试集的方法有两种:用自带函数(推荐),用参数sampleWeights(元素为0、非0的一维向量(行列都可))。

    方法一:使用加粗部分代码即可,setTrainTestSplit或setTrainTestSplitRatio、getTrainSamples、getTestSamples

    Ptr<TrainData> tdata=TrainData::create(samples, ROW_SAMPLE, responses);//其他参数默认,即所有样本包含在TrainData类对象tdata中
    //tdata->setTrainTestSplit(16000);//前16000行为训练集
    tdata->setTrainTestSplitRatio(0.8);//比例方式,前80%行作为训练集,推荐
    Mat trainSet = tdata->getTrainSamples();//获取训练集
    Mat testSet = tdata->getTestSamples();//获取测试集
    //Mat res = tdata->getResponses();//所有标签
    //Mat classNum = tdata->getClassLabels();//26类(分26类问题),65~90

    tdata是TrainData类型(包含了标签、特征),trainSet 、testSet 是Mat类型,只有特征。

    方法二:

    Mat sample_idx = Mat::zeros(1, samples.rows, CV_8U);
    Mat train_samples = sample_idx.colRange(0, (int)(data.rows*0.8)); //操作train_samples就是操作sample_idx,浅拷贝。sample_idx中前80%变为1    
    train_samples.setTo(Scalar::all(1));
    Ptr<TrainData> tdata=TrainData::create(samples, ROW_SAMPLE, responses, noArray(), sample_idx);//sample_idx中1对应的样本参与训练(非0有效)

    此时tdata中共有前80%的样本(既有特征又有标签)。

     【样本集注意点】

    如果样本集不同类别间是杂乱的,那么上述函数可以随意用。不过,一般我们会自己归类,如2分类问题。正样本、负样本各自独立有序。为了保证正负样本都会参与训练,所以最好不要拆分TrainData类型总样本集。否则可能会出现前80%行一种类别过多的情况。

    如果可以保证拆分后,训练集不同类别样本均衡,那可以使用训练集、测试集这种我们习惯的方式。否则,不推荐拆分。

  • 相关阅读:
    Linux下Java环境安装
    Go语言学习之10 Web开发与Mysql数据库
    Go语言学习之9 网络协议TCP、Redis与聊天室
    Redis入门指南之三(入门)
    Redis入门指南之一(简介)
    Go语言学习之8 goroutine详解、定时器与单元测试
    Redis入门指南之二(安装及配置)
    Go语言学习之7 接口实例、终端文件读写、异常处理
    Go语言学习之6 反射详解
    Go语言学习之5 进阶-排序、链表、二叉树、接口
  • 原文地址:https://www.cnblogs.com/xixixing/p/12512518.html
Copyright © 2011-2022 走看看