zoukankan      html  css  js  c++  java
  • cross-validation

    交叉验证

    • 模型拟合的程度好坏取决于数据的划分(主要指训练集和测试集的划分)-
    • 不代表模型的泛化能力datacamp

    Cross-validation is a vital step in evaluating a model. It maximizes the amount of data that is used to train the model, as during the course of training, the model is not only trained, but also tested on all of the available data.
    最大化的选择模型分训练集,使其泛化能力更好
    换句话说,就是选择最佳参数

    基本思想

    交叉验证的基本思想是把在某种意义下将原始数据(dataset)进行分组,一部分做为训练集(train set),另一部分做为验证集(validation set or test set),首先用训练集对分类器进行训练,再利用验证集来测试训练得到的模型(model),以此来做为评价分类器的性能指标。百度百科

    用途

    • 准确的调整模型的超参数(Hyperparameter),且这组参数对不同的数据,表现相对稳定
    • 在某些分类场景,你可以同时使用逻辑回归、决策树或聚类等多种算法建模,当不确定哪种算法效果更好时,可以使用交叉验证

    例子

    • 为了降低测试数据产生的偶然性,更好的做法便是采用「交叉验证」,还是以切分 5 份数据为例,交叉验证的做法是,对于同一个算法,同时训练出 5 个模型,每个模型采用不同的测试数据(例如模型 1 选用第 1 份,模型 2 选用第 2 份,以此类推),在所有模型都完成测试后,再对这 5 个模型的评估结果求平均,便可以得到一个相对稳定且更有说服力的算法。

    • 举个具体的例子,假设我们的模型采用决策树算法,该算法有个超参数是树的深度 height,我们可以将其设为 2,也可以设为 3,但不清楚设哪个数比较好,此时我们就可以使用「交叉验证」来帮我们决策,首先还是将数据 5等分,对每一个参数值,我们都训练 5次,输出 5种可能的测试结果,然后对这5个结果取平均,即可测试出哪个是我们想要的结果。简书

    可用API

    sklearn.model_selection

    5折交叉验证

    # Import the necessary modules
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_score
    
    # Create a linear regression object: reg
    reg = LinearRegression()
    
    # Compute 5-fold cross-validation scores: cv_scores
    cv_scores = cross_val_score(reg, X, y, cv=5)
    
    # Print the 5-fold cross-validation scores
    print(cv_scores)
    
    # Print the average 5-fold cross-validation score
    print("Average 5-Fold CV Score: {}".format(np.mean(cv_scores)))
    

    K折交叉验证

    # Import necessary modules
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_score
    
    
    # Create a linear regression object: reg
    reg = LinearRegression()
    
    # Perform 3-fold CV
    cvscores_3 = cross_val_score(reg, X, y, cv = 3)
    print(np.mean(cvscores_3))
    
    # Perform 10-fold CV
    cvscores_10 = cross_val_score(reg, X, y, cv = 10)
    print(np.mean(cvscores_10))
    
    <script.py> output:
        0.8718712782622108
        0.8436128620131201
    
  • 相关阅读:
    [大数据从入门到放弃系列教程]在IDEA的Java项目里,配置并加入Scala,写出并运行scala的hello world
    [大数据从入门到放弃系列教程]第一个spark分析程序
    Mac配置Scala和Spark最详细过程
    Mac配置Hadoop最详细过程
    [从零开始搭网站八]CentOS使用yum安装Redis的方法
    CentOS磁盘用完的解决办法,以及Tomcat的server.xml里无引用,但是项目仍启动的问题
    Mysql 删除重复数据只保留id最小的
    bootstrap媒体查询常用写法
    Arduino Uno 在win7 64位下的驱动问题
    VS项目模板文件位置
  • 原文地址:https://www.cnblogs.com/gaowenxingxing/p/12302890.html
Copyright © 2011-2022 走看看