zoukankan      html  css  js  c++  java
  • sklearn中cross_val_score、cross_val_predict的用法比较

    https://blog.csdn.net/cxx654/article/details/104813830

    sklearn中cross_val_score、cross_val_predict的用法比较

    程大海 2020-03-12 11:02:36 8444 收藏 21
    分类专栏: python编程 机器学习 文章标签: 机器学习 sklearn cross_val_score 交叉验证
    版权

    python编程
    同时被 2 个专栏收录
    49 篇文章0 订阅
    订阅专栏

    机器学习
    33 篇文章0 订阅
    订阅专栏
    交叉验证的概念,直接粘贴scikit-learn官网的定义:

    scikit-learn中计算交叉验证的函数:

    cross_val_score:得到K折验证中每一折的得分,K个得分取平均值就是模型的平均性能

    cross_val_predict:得到经过K折交叉验证计算得到的每个训练验证的输出预测

    方法:

    cross_val_score:分别在K-1折上训练模型,在余下的1折上验证模型,并保存余下1折中的预测得分

    cross_val_predict:分别在K-1上训练模型,在余下的1折上验证模型,并将余下1折中样本的预测输出作为最终输出结果的一部分

    结论:

    cross_val_score计算得到的平均性能可以作为模型的泛化性能参考

    cross_val_predict计算得到的样本预测输出不能作为模型的泛化性能参考

    from sklearn import datasets
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
    from sklearn import datasets
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
     
    # 加载鸢尾花数据集
    iris = datasets.load_iris()
    iris_train = iris.data
    iris_target = iris.target
    print(iris_train.shape)
    print(iris_target.shape)
    (150, 4)
    (150,)
     
    # 构建决策树分类模型
    tree_clf = DecisionTreeClassifier()
    tree_clf.fit(iris_train, iris_target)
    tree_predict = tree_clf.predict(iris_train)
    ​
    # 计算决策树分类模型的准确率
    from sklearn.metrics import accuracy_score
    print("Accuracy:", accuracy_score(iris_target, tree_predict))
    Accuracy: 1.0
     
    # 交叉验证cross_val_score输出每一折上的准确率
    from sklearn.model_selection import cross_val_predict, cross_val_score, cross_validate
    tree_scores = cross_val_score(tree_clf, iris_train, iris_target, cv=3)
    print(tree_scores)
    [0.98039216 0.92156863 1.        ]
     
    # 交叉验证cross_val_predict输出每个样本的预测结果
    tree_predict = cross_val_predict(tree_clf, iris_train, iris_target, cv=3)
    print(tree_predict)
    print(len(tree_predict))
    print(accuracy_score(iris_target, tree_predict))
    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1
     1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2
     2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 1 2 2 2 2 2 2 2 2 2
     2 2]
    150
    0.96
     
    print(tree_clf.predict(iris_train))
    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
     2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
     2 2]
     
    # 交叉验证cross_validate对cross_val_score结果进行包装,并包含fit的时间等信息
    tree_val = cross_validate(tree_clf, iris_train, iris_target, cv=3)
    print(tree_val)
    {'fit_time': array([0., 0., 0.]), 'score_time': array([0., 0., 0.]), 'test_score': array([0.98039216, 0.92156863, 0.97916667])}
     
    ​
     
    ​
  • 相关阅读:
    勾股定理
    委托应用-表单的创建和编辑
    学生成绩表(输入成绩后自动算出最高、最低、平均分)
    完美拖拽(点击回放运动轨迹)
    实心图案
    微博发布
    批量删除
    数组去重的方法
    模拟垂直滚动条
    点不到的NO
  • 原文地址:https://www.cnblogs.com/carl2380/p/15249532.html
Copyright © 2011-2022 走看看