https://blog.csdn.net/cxx654/article/details/104813830
Comparing cross_val_score and cross_val_predict in sklearn
程大海 2020-03-12 11:02:36
For the concept of cross-validation, I simply refer to the definition on the official scikit-learn website.
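As a quick illustration of K-fold splitting, the sketch below inspects the folds scikit-learn would build for the iris data with K=3. It assumes a classifier with a binary or multiclass target, in which case an integer `cv` is turned into a non-shuffled `StratifiedKFold`; the names `X`, `y`, `skf` and `test_indices` are illustrative only.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

X, y = load_iris(return_X_y=True)

# For a classifier with an integer cv, scikit-learn uses StratifiedKFold (no shuffling);
# here it is created explicitly so the splits can be inspected.
skf = StratifiedKFold(n_splits=3)

test_indices = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    # The model for this fold is trained on train_idx (the other K-1 folds)
    # and validated on test_idx (the held-out fold).
    print(f"fold {fold}: {len(train_idx)} training samples, "
          f"{len(test_idx)} validation samples, "
          f"class counts in validation fold: {np.bincount(y[test_idx])}")
    test_indices.append(test_idx)

# Every sample is used for validation exactly once across the K folds.
all_test = np.concatenate(test_indices)
print(np.array_equal(np.sort(all_test), np.arange(len(y))))  # True
```

The exact fold sizes can depend on the scikit-learn version, so only the "each sample validated once" property is worth relying on.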
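Functions in scikit-learn for computing cross-validation:
cross_val_score: returns the score on each of the K folds; the average of the K scores is the model's mean performance
cross_val_predict: returns, for every sample, the prediction made while that sample was in the held-out fold of K-fold cross-validation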
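The difference in what the two functions return shows up directly in the output shapes. A minimal sketch on iris with K=3; `random_state=0` is an assumption added only so repeated runs give the same numbers.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0)  # fixed seed only for reproducibility

scores = cross_val_score(clf, X, y, cv=3)   # one score per fold
preds = cross_val_predict(clf, X, y, cv=3)  # one out-of-fold prediction per sample

print(scores.shape)  # (3,)
print(preds.shape)   # (150,)
print("mean CV accuracy:", scores.mean())
```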
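How they work:
cross_val_score: for each fold, trains the model on the other K-1 folds, evaluates it on the remaining fold, and stores the score obtained on that fold
cross_val_predict: for each fold, trains the model on the other K-1 folds, predicts the remaining fold, and uses the predictions for that fold's samples as part of the final output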
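The two procedures above can be sketched as a single manual loop. This is a simplified illustration, not scikit-learn's internal implementation (which also handles parallelism, scoring objects and more); it assumes accuracy as the score (the default `score` of a classifier) and a fixed `random_state=0` so that the manual models match the ones the library fits.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0)
cv = StratifiedKFold(n_splits=3)

fold_scores = []
oof_pred = np.empty_like(y)  # out-of-fold predictions, filled one fold at a time
for train_idx, test_idx in cv.split(X, y):
    model = clone(clf).fit(X[train_idx], y[train_idx])     # train on K-1 folds
    pred = model.predict(X[test_idx])                      # predict the held-out fold
    fold_scores.append(accuracy_score(y[test_idx], pred))  # what cross_val_score keeps
    oof_pred[test_idx] = pred                              # what cross_val_predict keeps

fold_scores = np.array(fold_scores)
print(fold_scores)
print(np.allclose(fold_scores, cross_val_score(clf, X, y, cv=cv)))     # True
print(np.array_equal(oof_pred, cross_val_predict(clf, X, y, cv=cv)))  # True
```

Both functions run the same train/validate loop; they only differ in whether they keep one number per fold or the predictions themselves.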
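Conclusions:
The mean score computed by cross_val_score can be used as an estimate of the model's generalization performance
The per-sample predictions computed by cross_val_predict should not be used directly as a measure of generalization performance: a metric computed on the pooled predictions can differ from the per-fold scores unless all folds have the same size and the metric decomposes over individual samples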
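To see why pooled predictions and per-fold scores can disagree, the sketch below uses 4 folds on 150 samples, which forces unequal fold sizes. It assumes a fixed `random_state=0` so both functions fit identical models; for accuracy, the pooled value then equals the fold-size-weighted mean of the fold scores, not their plain mean.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0)
cv = StratifiedKFold(n_splits=4)  # 150 samples / 4 folds -> folds of unequal size

scores = cross_val_score(clf, X, y, cv=cv)
preds = cross_val_predict(clf, X, y, cv=cv)
sizes = np.array([len(test) for _, test in cv.split(X, y)])

mean_of_folds = scores.mean()                 # plain average of per-fold accuracies
pooled = accuracy_score(y, preds)             # one accuracy over all out-of-fold predictions
weighted = np.average(scores, weights=sizes)  # per-fold accuracies weighted by fold size

print("mean of fold scores:", mean_of_folds)
print("pooled accuracy:    ", pooled)
print("size-weighted mean: ", weighted)
```

For metrics that do not decompose per sample (for example ROC AUC or macro F1), even equally sized folds do not guarantee that the pooled value matches the mean of the fold scores.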
from sklearn import datasets
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Load the iris dataset
iris = datasets.load_iris()
iris_train = iris.data
iris_target = iris.target
print(iris_train.shape)
print(iris_target.shape)
# Output:
# (150, 4)
# (150,)

# Build a decision tree classifier, fit it on all samples and predict those same samples
tree_clf = DecisionTreeClassifier()
tree_clf.fit(iris_train, iris_target)
tree_predict = tree_clf.predict(iris_train)

# Accuracy of the decision tree on its own training data
from sklearn.metrics import accuracy_score
print("Accuracy:", accuracy_score(iris_target, tree_predict))
# Output: Accuracy: 1.0

# cross_val_score: accuracy on each of the 3 folds
from sklearn.model_selection import cross_val_predict, cross_val_score, cross_validate
tree_scores = cross_val_score(tree_clf, iris_train, iris_target, cv=3)
print(tree_scores)
# Output: [0.98039216 0.92156863 1.        ]

# cross_val_predict: the out-of-fold prediction for every sample
tree_predict = cross_val_predict(tree_clf, iris_train, iris_target, cv=3)
print(tree_predict)
print(len(tree_predict))
print(accuracy_score(iris_target, tree_predict))
# Output:
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2]
# 150
# 0.96

# cross_val_predict fits clones of tree_clf, so tree_clf itself is still the model
# fitted on all 150 samples above and reproduces the training labels exactly
print(tree_clf.predict(iris_train))
# Output:
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]

# cross_validate returns the same per-fold scores as cross_val_score (test_score)
# plus timing information (fit_time, score_time)
tree_val = cross_validate(tree_clf, iris_train, iris_target, cv=3)
print(tree_val)
# Output:
# {'fit_time': array([0., 0., 0.]), 'score_time': array([0., 0., 0.]), 'test_score': array([0.98039216, 0.92156863, 0.97916667])}
# The last fold scores 0.97916667 here but 1.0 in cross_val_score above: without a fixed
# random_state, DecisionTreeClassifier randomly permutes the features at each split,
# so ties between equally good splits can be broken differently from run to run.
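A sketch of how to make the comparison above reproducible, assuming `random_state=42` (any fixed integer would do). With the seed fixed, `cross_validate` reports exactly the same `test_score` as `cross_val_score`; `return_train_score=True` additionally exposes the training-fold scores, which helps spot overfitting like the 1.0 training accuracy seen earlier.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# A fixed random_state makes the tree's feature permutation (and tie-breaking) repeatable
tree_clf = DecisionTreeClassifier(random_state=42)

scores = cross_val_score(tree_clf, X, y, cv=3)
results = cross_validate(tree_clf, X, y, cv=3, return_train_score=True)

print(sorted(results))  # ['fit_time', 'score_time', 'test_score', 'train_score']
print(scores)
print(results["test_score"])   # identical to scores
print(results["train_score"])  # accuracy on the K-1 training folds of each split
```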