zoukankan      html  css  js  c++  java
  • 网格搜索(参数选择)

    首先说交叉验证。
    交叉验证(Cross validation)是一种评估统计分析、机器学习算法对独立于训练数据的数据集的泛化能力(generalize), 能够避免过拟合问题。
    交叉验证一般要尽量满足:
    1)训练集的比例要足够多,一般大于一半
    2)训练集和测试集要均匀抽样

    交叉验证主要分成以下几类:

    1)Double cross-validation
    Double cross-validation也称2-fold cross-validation(2-CV),作法是将数据集分成两个相等大小的子集,进行两回合的分类器训练。在第一回合中,一个子集作为训练集,另一个作为测试集;在第二回合中,则将训练集与测试集对换后,再次训练分类器,而其中我们比较关心的是两次测试集的识别率。不过在实际中2-CV并不常用,主要原因是训练集样本数太少,通常不足以代表母体样本的分布,导致测试阶段识别率容易出现明显落差。此外,2-CV中子集的变异度大,往往无法达到「实验过程必须可以被复制」的要求。

    2)k-folder cross-validation(k折交叉验证)
    K-fold cross-validation (k-CV)则是Double cross-validation的延伸,做法是将数据集分成k个子集,每个子集均做一次测试集,其余的作为训练集。k-CV交叉验证重复k次,每次选择一个子集作为测试集,并将k次的平均交叉验证识别率作为结果。
    优点:所有的样本都被作为了训练集和测试集,每个样本都被验证一次。10-folder通常被使用。

    3)leave-one-out cross-validation(LOOCV留一验证法)
    假设数据集中有n个样本,那LOOCV也就是n-CV,意思是每个样本单独作为一次测试集,剩余n-1个样本则做为训练集。
    优点:
    1)每一回合中几乎所有的样本皆用于训练model,因此最接近母体样本的分布,估测所得的generalization error比较可靠。 因此在实验数据集样本较少时,可以考虑使用LOOCV。
    2)实验过程中没有随机因素会影响实验数据,确保实验过程是可以被复制的。
    但LOOCV的缺点则是计算成本高,为需要建立的models数量与总样本数量相同,当总样本数量相当多时,LOOCV在实作上便有困难,除非每次训练model的速度很快,或是可以用平行化计算减少计算所需的时间。

    libsvm提供了 void svm_cross_validation(const struct svm_problem *prob, const struct svm_parameter *param, int nr_fold, double *target)方法,参数含义如下:

    prob:待解决的分类问题,就是样本数据。
    param:svm训练参数。
    nr_fold:顾名思义就是k折交叉验证中的k,如果k=n的话就是留一法了。
    target:预测值,如果是分类问题的话就是类别标签了。

    然后我们讨论下参数选择。
    使用svm,无论是libsvm还是svmlight,都需要对参数进行设置。以RBF核为例,在《A Practical Guide to Support Vector Classi cation》一文中作者提到在RBF核中有2个参数:C和g。对于一个给定的问题,我们事先不知道C和g取多少最优,因此我们要进行模型选择(参数搜索)。这样做的目标是找到好的(C, g)参数对,使得分类器能够精确地预测未知的数据,比如测试集。需要注意的是在在训练集上追求高精确度可能是没用的(意指泛化能力)。根据前一部分所说的,衡量泛化能力要用到交叉验证。

    在文章中作者推荐使用“网格搜索”来寻找最优的C和g。所谓的网格搜索就是尝试各种可能的(C, g)对值,然后进行交叉验证,找出使交叉验证精确度最高的(C, g)对。“网格搜索”的方法很直观但是看起来有些原始。事实上有许多高级的算法,比如可以使用一些近似算法或启发式的搜索来降低复杂度。但是我们倾向于使用“网格搜索”这一简单的方法:
    1)从心理上讲,不进行全面的参数搜索而是使用近似算法或启发式算法让人感觉不安全。
    2)如果参数比较少,“网格搜索”的复杂度比高级算法高不了多少。
    3)“网格搜索”可并行性高,因为每个(C, g)对是相互独立的。

    说了那么大半天,其实“网格搜索”就是n层循环,n是参数个数,仍然以RBF核为例,编程实现如下:

    for(double c=c_begin;c<c_end;c+=c_step)
    {
    for(double g=g_begin;g<g_end;g+=g_step)
    {
    //这里进行交叉验证,计算精确度。
    }
    }

    通过上述两层循环找到最优的C和g就可以了。

    附录:
    使用Cross-Validation时常犯的错误
    由于实验室许多研究都有用到evolutionary algorithms(EA)与classifiers,所使用的fitness function中通常都有用到classifier的辨识率,然而把cross-validation用错的案例还不少。前面说过,只有training data才可以用于model的建构,所以只有training data的辨识率才可以用在fitness function中。而EA是训练过程用来调整model最佳参数的方法,所以只有在EA结束演化后,model参数已经固定了,这时候才可以使用test data。(当然如果想造假的话就把测试集的数据参与进模型训练,这样得到的模型效果多少会好些,因为模型本身已经包含了测试集的先验知识,测试集对它来说不再是未知数据。)

    那EA跟cross-validation要如何搭配呢?Cross-validation的本质是用来估测(estimate)某个classification method对一组dataset的generalization error,不是用来设计classifier的方法,所以cross-validation不能用在EA的fitness function中,因为与fitness function有关的样本都属于training set,那试问哪些样本才是test set呢?如果某个fitness function中用了cross-validation的training或test辨识率,那么这样的实验方法已经不能称为 cross-validation了。

    EA与k-CV正确的搭配方法,是将dataset分成k等份的subsets后,每次取1份 subset作为test set,其余k-1份作为training set,并且将该组training set套用到EA的fitness function计算中(至于该training set如何进一步利用则没有限制)。因此,正确的k-CV 会进行共k次的EA演化,建立k个classifiers。而k-CV的test辨识率,则是k组test sets对应到EA训练所得的k个classifiers辨识率之平均值。

  • 相关阅读:
    navicat 连接 mysql 出现Client does not support authentication protocol requested by server解决方案
    oracle的concat、convert、listagg函数(字符串拼接和类型转换)
    oracle的decode、sign、nvl,case...then函数
    where、having区别
    Oracle的rollup、cube、grouping sets函数
    IP地址,子网掩码,网段表示法,默认网关,DNS服务器详解,DNS域名设计
    springmvc 参数解析绑定原理
    eclipse运行mapreduce的wordcount
    linux命令帮助 man bash
    shell学习笔记3-后台执行命令
  • 原文地址:https://www.cnblogs.com/sddai/p/6440797.html
Copyright © 2011-2022 走看看