zoukankan      html  css  js  c++  java
  • 几种交叉验证(cross validation)方式的比较

    模型评价的目的:通过模型评价,我们知道当前训练模型的好坏,泛化能力如何?从而知道是否可以应用在解决问题上,如果不行,那又是哪里出了问题?

    train_test_split

    在分类问题中,我们通常通过对训练集进行train_test_split,划分成train 和test 两部分,其中train用来训练模型,test用来评估模型,模型通过fit方法从train数据集中学习,然后调用score方法在test集上进行评估,打分;从分数上我们可以知道 模型当前的训练水平如何。

    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    
    cancer = load_breast_cancer()
    X_train,X_test,y_train,y_test = train_test_split(cancer.data,cancer.target,random_state=0)
    
    logreg = LogisticRegression().fit(X_train,y_train)
    print("Test set score:{:.2f}".format(logreg.score(X_test,y_test)))
    

    输出:
    output: Test set score:0.96
    然而,这种方式存:只进行了一次划分,数据结果具有偶然性,如果在某次划分中,训练集里全是容易学习的数据,测试集里全是复杂的数据,这样就会导致最终的结果不尽如意;反之,亦是如此。

    Standard Cross Validation

    针对上面通过train_test_split划分,从而进行模型评估方式存在的弊端,提出Cross Validation 交叉验证。
    Cross Validation:简言之,就是进行多次train_test_split划分;每次划分时,在不同的数据集上进行训练、测试评估,从而得出一个评价结果;如果是5折交叉验证,意思就是在原始数据集上,进行5次划分,每次划分进行一次训练、评估,最后得到5次划分后的评估结果,一般在这几次评估结果上取平均得到最后的 评分。k-fold cross-validation ,其中,k一般取5或10。
    标准交叉验证standard Cross validation

    demo:

    from sklearn.model_selection import cross_val_score
    
    logreg = LogisticRegression()
    scores = cross_val_score(logreg,cancer.data, cancer.target) #cv:默认是3折交叉验证,可以修改cv=5,变成5折交叉验证。
    print("Cross validation scores:{}".format(scores))
    print("Mean cross validation score:{:2f}".format(scores.mean()))
    

    输出:

    Cross validation scores:[0.93684211 0.96842105 0.94179894]
    Mean cross validation score:0.949021
    

    交叉验证的优点:

    • 原始采用的train_test_split方法,数据划分具有偶然性;交叉验证通过多次划分,大大降低了这种由一次随机划分带来的偶然性,同时通过多次划分,多次训练,模型也能遇到各种各样的数据,从而提高其泛化能力;
    • 与原始的train_test_split相比,对数据的使用效率更高。train_test_split,默认训练集、测试集比例为3:1,而对交叉验证来说,如果是5折交叉验证,训练集比测试集为4:1;10折交叉验证训练集比测试集为9:1。数据量越大,模型准确率越高!

    缺点:

    • 这种简答的交叉验证方式,从上面的图片可以看出来,每次划分时对数据进行均分,设想一下,会不会存在一种情况:数据集有5类,抽取出来的也正好是按照类别划分的5类,也就是说第一折全是0类,第二折全是1类,等等;这样的结果就会导致,模型训练时,没有学习到测试集中数据的特点,从而导致模型得分很低,甚至为0,!为了避免这种情况,又出现了其他的各种交叉验证方式。

    Stratified k-fold cross validation

    分层交叉验证(Stratified k-fold cross validation):首先它属于交叉验证类型,分层的意思是说在每一折中都保持着原始数据中各个类别的比例关系,比如说:原始数据有3类,比例为1:2:1,采用3折分层交叉验证,那么划分的3折中,每一折中的数据类别保持着1:2:1的比例,这样的验证结果更加可信。
    通常情况下,可以设置cv参数来控制几折,但是我们希望对其划分等加以控制,所以出现了KFold,KFold控制划分折,可以控制划分折的数目,是否打乱顺序等,可以赋值给cv,用来控制划分。
    标准交叉验证 vs 分层交叉验证

    demo:

    from sklearn.datasets import load_iris
    from sklearn.model_selection import StratifiedKFold,cross_val_score
    from sklearn.linear_model import LogisticRegression
    
    iris = load_iris()
    print('Iris labels:
    {}'.format(iris.target))
    logreg = LogisticRegression()
    strKFold = StratifiedKFold(n_splits=3,shuffle=False,random_state=0)
    scores = cross_val_score(logreg,iris.data,iris.target,cv=strKFold)
    print("straitified cross validation scores:{}".format(scores))
    print("Mean score of straitified cross validation:{:.2f}".format(scores.mean()))
    

    输出:

    Iris labels:
    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
    1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
    2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    2 2]
    straitified cross validation scores:[0.96078431 0.92156863 0.95833333]
    Mean score of straitified cross validation:0.95
    

    Leave-one-out Cross-validation 留一法

    留一法Leave-one-out Cross-validation:是一种特殊的交叉验证方式。顾名思义,如果样本容量为n,则k=n,进行n折交叉验证,每次留下一个样本进行验证。主要针对小样本数据。

    demo:

    from sklearn.datasets import load_iris
    from sklearn.model_selection import LeaveOneOut,cross_val_score
    from sklearn.linear_model import LogisticRegression
    
    iris = load_iris()
    print('Iris labels:
    {}'.format(iris.target))
    logreg = LogisticRegression()
    loout = LeaveOneOut()
    scores = cross_val_score(logreg,iris.data,iris.target,cv=loout)
    print("leave-one-out cross validation scores:{}".format(scores))
    print("Mean score of leave-one-out cross validation:{:.2f}".format(scores.mean()))
    

    输出:

    Iris labels:
    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
    1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
    2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
    2 2]
    leave-one-out cross validation scores:[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
    1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
    1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1.
    1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
    1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
    1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
    1. 1. 1. 1. 1. 1.]
    Mean score of leave-one-out cross validation:0.95
    

    Shuffle-split cross-validation

    控制更加灵活:可以控制划分迭代次数、每次划分时测试集和训练集的比例(也就是说:可以存在既不在训练集也不再测试集的情况);
    shuffle split cross validation
    demo:

    from sklearn.datasets import load_iris
    from sklearn.model_selection import ShuffleSplit,cross_val_score
    from sklearn.linear_model import LogisticRegression
    
    iris = load_iris()
    shufspl = ShuffleSplit(train_size=.5,test_size=.4,n_splits=8) #迭代8次;
    logreg = LogisticRegression()
    scores = cross_val_score(logreg,iris.data,iris.target,cv=shufspl)
    
    print("shuffle split cross validation scores:
    {}".format(scores))
    print("Mean score of shuffle split cross validation:{:.2f}".format(scores.mean()))
    

    输出:

    shuffle split cross validation scores:
    [0.95      0.95      0.95      0.95      0.93333333 0.96666667
    0.96666667 0.91666667]
    Mean score of shuffle split cross validation:0.95
    

    我的博客即将搬运同步至腾讯云+社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=2l6rvdr2fmcko

  • 相关阅读:
    VS2008编写MFC程序--使用opencv2.4()
    November 02nd, 2017 Week 44th Thursday
    November 01st, 2017 Week 44th Wednesday
    October 31st, 2017 Week 44th Tuesday
    October 30th, 2017 Week 44th Monday
    October 29th, 2017 Week 44th Sunday
    October 28th, 2017 Week 43rd Saturday
    October 27th, 2017 Week 43rd Friday
    October 26th, 2017 Week 43rd Thursday
    October 25th, 2017 Week 43rd Wednesday
  • 原文地址:https://www.cnblogs.com/ysugyl/p/8707887.html
Copyright © 2011-2022 走看看