zoukankan      html  css  js  c++  java
  • sklearn中cross_val_score、cross_val_predict的用法比较

    https://blog.csdn.net/cxx654/article/details/104813830

    sklearn中cross_val_score、cross_val_predict的用法比较

    程大海 2020-03-12 11:02:36 8444 收藏 21
    分类专栏: python编程 机器学习 文章标签: 机器学习 sklearn cross_val_score 交叉验证
    版权

    python编程
    同时被 2 个专栏收录
    49 篇文章0 订阅
    订阅专栏

    机器学习
    33 篇文章0 订阅
    订阅专栏
    交叉验证的概念,直接粘贴scikit-learn官网的定义:

    scikit-learn中计算交叉验证的函数:

    cross_val_score:得到K折验证中每一折的得分,K个得分取平均值就是模型的平均性能

    cross_val_predict:得到经过K折交叉验证计算得到的每个训练验证的输出预测

    方法:

    cross_val_score:分别在K-1折上训练模型,在余下的1折上验证模型,并保存余下1折中的预测得分

    cross_val_predict:分别在K-1上训练模型,在余下的1折上验证模型,并将余下1折中样本的预测输出作为最终输出结果的一部分

    结论:

    cross_val_score计算得到的平均性能可以作为模型的泛化性能参考

    cross_val_predict计算得到的样本预测输出不能作为模型的泛化性能参考

    from sklearn import datasets
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
    from sklearn import datasets
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
     
    # 加载鸢尾花数据集
    iris = datasets.load_iris()
    iris_train = iris.data
    iris_target = iris.target
    print(iris_train.shape)
    print(iris_target.shape)
    (150, 4)
    (150,)
     
    # 构建决策树分类模型
    tree_clf = DecisionTreeClassifier()
    tree_clf.fit(iris_train, iris_target)
    tree_predict = tree_clf.predict(iris_train)
    ​
    # 计算决策树分类模型的准确率
    from sklearn.metrics import accuracy_score
    print("Accuracy:", accuracy_score(iris_target, tree_predict))
    Accuracy: 1.0
     
    # 交叉验证cross_val_score输出每一折上的准确率
    from sklearn.model_selection import cross_val_predict, cross_val_score, cross_validate
    tree_scores = cross_val_score(tree_clf, iris_train, iris_target, cv=3)
    print(tree_scores)
    [0.98039216 0.92156863 1.        ]
     
    # 交叉验证cross_val_predict输出每个样本的预测结果
    tree_predict = cross_val_predict(tree_clf, iris_train, iris_target, cv=3)
    print(tree_predict)
    print(len(tree_predict))
    print(accuracy_score(iris_target, tree_predict))
    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1
     1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2
     2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 1 2 2 2 2 2 2 2 2 2
     2 2]
    150
    0.96
     
    print(tree_clf.predict(iris_train))
    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
     2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
     2 2]
     
    # 交叉验证cross_validate对cross_val_score结果进行包装,并包含fit的时间等信息
    tree_val = cross_validate(tree_clf, iris_train, iris_target, cv=3)
    print(tree_val)
    {'fit_time': array([0., 0., 0.]), 'score_time': array([0., 0., 0.]), 'test_score': array([0.98039216, 0.92156863, 0.97916667])}
     
    ​
     
    ​
  • 相关阅读:
    PHP基本的语法以及和Java的差别
    Linux 性能測试工具
    【Oracle 集群】Linux下Oracle RAC集群搭建之Oracle DataBase安装(八)
    【Oracle 集群】Oracle 11G RAC教程之集群安装(七)
    【Oracle 集群】11G RAC 知识图文详细教程之RAC在LINUX上使用NFS安装前准备(六)
    【Oracle 集群】ORACLE DATABASE 11G RAC 知识图文详细教程之RAC 特殊问题和实战经验(五)
    【Oracle 集群】ORACLE DATABASE 11G RAC 知识图文详细教程之缓存融合技术和主要后台进程(四)
    【Oracle 集群】ORACLE DATABASE 11G RAC 知识图文详细教程之RAC 工作原理和相关组件(三)
    Oracle 集群】ORACLE DATABASE 11G RAC 知识图文详细教程之ORACLE集群概念和原理(二)
    【Oracle 集群】ORACLE DATABASE 11G RAC 知识图文详细教程之集群概念介绍(一)
  • 原文地址:https://www.cnblogs.com/carl2380/p/15249532.html
Copyright © 2011-2022 走看看