一、简介
交叉验证(Cross validation,简称CV)是在机器学习建立模型和验证模型参数时常用的办法,一般被用于评估一个机器学习模型的表现。交叉验证的基本思想是把在某种意义下将原始数据(dataset)进行分组,一部分做为训练集(train set),另一部分做为验证集(validation set or test set),首先用训练集对分类器进行训练,再利用验证集来测试训练得到的模型(model),以此来做为评价分类器的性能指标。常见CV的方法有Holdout 验证、K-fold cross-validation、留一验证。
1. Holdout 验证
方法:将原始数据随机分为两组,一组作为训练集,一组作为验证集。利用训练集训练分类器,利用验证集验证模型,记录最后的分类准确率为此Hold-Out Method下分类器的性能指标。
优缺点:此方法的好处是处理简单,只需要随机把原始数据分为两组即可。但是,Holdout 验证严格意义上并不能算是CV,因为这种方法并没有“交叉”的思想。由于是随机地将原始数据分组,所以最后验证集分类准确率的高低与原始数据的分组有很大的关系,所以这种方法得到的结果其实并不具有说服性。
2. K-fold cross-validation
方法:将原始数据分割成K组数据集,每个单独的数据集作为验证集,其余的K-1个数据集用来训练,交叉验证重复K次,共得到K个模型,用这K个模型最终的验证集的分类准确率的平均数作为此K-CV下分类器的性能指标。K一般大于等于2,实际操作时一般从3开始取,一般而言K=10是最常用的。
优缺点:K-CV作为方法1的演进,可以有效地避免过学习以及欠学习状态的发生,最后得到的结果也比较具有说服性。其主要缺点在于K值的选取上。
3. Leave-One_Out Cross Validation(LOO-CV)
方法:如果设原始数据有N个样本,那么LOO-CV就是N-CV,即每个样本单独作为验证集,其余的N-1个样本作为训练集,所以LOO-CV会得到N个模 型,用这N个模型最终的验证集的分类准确率的平均数作为此下LOO-CV分类器的性能指标。
优点:相比于前面的K-CV,LOO-CV有两个明显的优点:1)每一回合中几乎所有的样本皆用于训练模型,因此最接近原始样本的分布,这样评估所得的结果比较可靠。2)实验过程中没有随机因素会影响实验数据,确保实验过程是可以被复制的。但LOO-CV的缺点则是计算成本高,因为需要建立的模型数量与原始数据样本数量相同,当原始数据样本数量相当多时,LOO-CV在实作上便有困难几乎就是不显示,除非每次训练分类器得到模型的速度很快,或是可以用并行化计算减少计算所需的时间。
二、 MATLAB实践(K-CV)
在使用svm时,需要采用交叉验证选择最佳参数c和g。libsvm中的svmtrain函数内置交叉验证选项,svmtrain的options如下:
-s svm类型:SVM模型设置类型(默认值为0) 0:C - SVC 1:nu - SVC 2:one - class SVM 3: epsilon - SVR 4: nu - SVR - t 核函数类型:核函数设置类型(默认值为2) 0:线性核函数 u'v 1:多项式核函数(r *u'v + coef0)^degree 2:RBF 核函数 exp( -r|u - v|^2) 3:sigmiod核函数 tanh(r * u'v + coef0) - d degree:核函数中的 degree 参数设置(针对多项式核函数,默认值为3) - g r(gama):核函数中的gama参数设置(针对多项式/sigmoid 核函数/RBF/,默认值为属性数目的倒数) - r coef0:核函数中的coef0参数设置(针对多项式/sigmoid核函数,默认值为0) - c cost:设置 C - SVC,epsilon - SVR 和 nu - SVR的参数(默认值为1) - n nu:设置 nu-SVC ,one - class SVM 和 nu - SVR的参数 - p epsilon:设置 epsilon - SVR 中损失函数的值(默认值为0.1) - m cachesize:设置 cache 内存大小,以 MB 为单位(默认值为100) - e eps:设置允许的终止阈值(默认值为0.001) - h shrinking:是否使用启发式,0或1(默认值为1) - wi weight:设置第几类的参数 C 为 weight * C(对于 C - SVC 中的 C,默认值为1) - v n:n - fold 交互检验模式,n为折数,必须大于等于2
其中,-v 随机地将数据分为n部分,并计算交互检验准确度和均方根误差。大致实现代码如下:
% Cross_Validation % K-fold cross validation % Author Ethan % Date 2020/4/10 % Version 1.0 fprintf('Beginning crossvalidation ') crossval_start = tic; %best_accuracy = 0; best_cv = 0; best_c = 0; best_g = 0; k = 3; % number of folds for log2c = -5:5 for log2g = -5:5 cmd = ['-t 0','-v ',num2str(k),'-c ',num2str(2^log2c),'-g ',num2str(2^log2g)]; cv = svmtrain(labels,train_matrix,cmd); if cv >= best_cv best_cv = cv; best_c = 2^log2c; best_g = 2^log2g; end end end crossval_elapsed = toc(crossval_start); fprintf('SVM crosvalidation done in: %f seconds. ',crossval_elapsed); fprintf('Best crossval reached: %d, with cost=%d ', best_cv, best_c); %svm_params = ['-t ',num2str(0) ,' -c ', num2str(best_c),' -b 1'];