zoukankan      html  css  js  c++  java
  • cross-validation

    交叉验证

    • 模型拟合的程度好坏取决于数据的划分(主要指训练集和测试集的划分)-
    • 不代表模型的泛化能力datacamp

    Cross-validation is a vital step in evaluating a model. It maximizes the amount of data that is used to train the model, as during the course of training, the model is not only trained, but also tested on all of the available data.
    最大化的选择模型分训练集,使其泛化能力更好
    换句话说,就是选择最佳参数

    基本思想

    交叉验证的基本思想是把在某种意义下将原始数据(dataset)进行分组,一部分做为训练集(train set),另一部分做为验证集(validation set or test set),首先用训练集对分类器进行训练,再利用验证集来测试训练得到的模型(model),以此来做为评价分类器的性能指标。百度百科

    用途

    • 准确的调整模型的超参数(Hyperparameter),且这组参数对不同的数据,表现相对稳定
    • 在某些分类场景,你可以同时使用逻辑回归、决策树或聚类等多种算法建模,当不确定哪种算法效果更好时,可以使用交叉验证

    例子

    • 为了降低测试数据产生的偶然性,更好的做法便是采用「交叉验证」,还是以切分 5 份数据为例,交叉验证的做法是,对于同一个算法,同时训练出 5 个模型,每个模型采用不同的测试数据(例如模型 1 选用第 1 份,模型 2 选用第 2 份,以此类推),在所有模型都完成测试后,再对这 5 个模型的评估结果求平均,便可以得到一个相对稳定且更有说服力的算法。

    • 举个具体的例子,假设我们的模型采用决策树算法,该算法有个超参数是树的深度 height,我们可以将其设为 2,也可以设为 3,但不清楚设哪个数比较好,此时我们就可以使用「交叉验证」来帮我们决策,首先还是将数据 5等分,对每一个参数值,我们都训练 5次,输出 5种可能的测试结果,然后对这5个结果取平均,即可测试出哪个是我们想要的结果。简书

    可用API

    sklearn.model_selection

    5折交叉验证

    # Import the necessary modules
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_score
    
    # Create a linear regression object: reg
    reg = LinearRegression()
    
    # Compute 5-fold cross-validation scores: cv_scores
    cv_scores = cross_val_score(reg, X, y, cv=5)
    
    # Print the 5-fold cross-validation scores
    print(cv_scores)
    
    # Print the average 5-fold cross-validation score
    print("Average 5-Fold CV Score: {}".format(np.mean(cv_scores)))
    

    K折交叉验证

    # Import necessary modules
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_score
    
    
    # Create a linear regression object: reg
    reg = LinearRegression()
    
    # Perform 3-fold CV
    cvscores_3 = cross_val_score(reg, X, y, cv = 3)
    print(np.mean(cvscores_3))
    
    # Perform 10-fold CV
    cvscores_10 = cross_val_score(reg, X, y, cv = 10)
    print(np.mean(cvscores_10))
    
    <script.py> output:
        0.8718712782622108
        0.8436128620131201
    
  • 相关阅读:
    轻松构建微服务之分布式事物
    Java和操作系统交互细节
    网络内核之TCP是如何发送和接收消息的
    最全的微服务知识科普
    【译】TCP Implementation in Linux
    轻松构建微服务之分库分表
    OSX的一些基本知识
    Delphi中使用比较少的一些语法
    热烈祝贺我的博客开通
    Windows 服务快捷启动命令
  • 原文地址:https://www.cnblogs.com/gaowenxingxing/p/12302890.html
Copyright © 2011-2022 走看看