zoukankan      html  css  js  c++  java
  • 评估指标【交叉验证&ROC曲线】

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Mon Sep 10 11:21:27 2018
     4 
     5 @author: zhen
     6 """
     7 from sklearn.datasets import fetch_mldata
     8 import numpy as np
     9 from sklearn.linear_model import SGDClassifier
    10 from sklearn.model_selection import cross_val_score
    11 from sklearn.model_selection import cross_val_predict
    12 from sklearn.metrics import precision_recall_curve
    13 import matplotlib
    14 import matplotlib.pyplot as plt
    15 from sklearn.metrics import roc_curve
    16 from sklearn.metrics import roc_auc_score
    17 from sklearn.ensemble import RandomForestClassifier
    18 
    # Load MNIST through the legacy mldata loader, reading from a local cache
    # directory.
    # NOTE(review): fetch_mldata was deprecated and later removed from
    # scikit-learn; newer releases need fetch_openml instead — confirm the
    # installed version before running.
    mnist = fetch_mldata('MNIST original', data_home='D:/AnalyseData学习资源库/人工智能开发/分类评估/资料/test_data_home')

    x, y = mnist['data'], mnist['target']
    some_digit = x[36000]  # sample taken from row 36000

    # Each sample is a flat vector; view it as a 28x28 image for display.
    some_digit_image = some_digit.reshape(28, 28)

    # NOTE(review): vmin=0 / vmax=1 clips the colour scale to [0, 1]; if the
    # pixel values span 0-255 this renders a two-tone image — confirm intended.
    plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,
               interpolation='nearest', vmin=0, vmax=1)
    plt.axis('off')
    plt.show()

    # The first 60000 rows are used for training, the remainder for testing.
    x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]

    # Shuffle the training rows (features and labels with the same permutation).
    shuffle_index = np.random.permutation(60000)
    x_train, y_train = x_train[shuffle_index], y_train[shuffle_index]

    # Binary targets: True where the label equals 5.
    y_train_5 = (y_train == 5)
    y_test_5 = (y_test == 5)

    sgd_clf = SGDClassifier(loss='log', random_state=42, max_iter=1000, tol=1e-4)
    sgd_clf.fit(x_train, y_train_5)

    result = sgd_clf.predict([some_digit])

    # 3-fold cross-validated accuracy, precision and recall.
    print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='accuracy'))
    print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='precision'))
    print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='recall'))

    # The script used to refit sgd_clf here on the same data. That refit was
    # dropped: cross_val_score fits clones, so the model fitted above is
    # untouched, and random_state is fixed, so fitting again would only repeat
    # the same training run.
    y_scores = sgd_clf.decision_function([some_digit])

    # Turn the raw score into a prediction at two different thresholds; the
    # second assignment deliberately replaces the first.
    threshold = 0
    y_some_digit_pred = (y_scores > threshold)

    threshold = 200000
    y_some_digit_pred = (y_scores > threshold)

    # cv: number of folds the dataset is split into.
    y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3, method='decision_function')

    precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
    62 
    63 
    def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
        """Draw precision and recall against the decision threshold and show it.

        Args:
            precisions: precision values; the last element is dropped before
                plotting so the series lines up with ``thresholds``.
            recalls: recall values; the last element is dropped likewise.
            thresholds: threshold values used for the x axis.
        """
        curves = (
            (precisions, 'b--', 'Precision'),
            (recalls, 'r--', 'Recall'),
        )
        for values, line_style, curve_label in curves:
            plt.plot(thresholds, values[:-1], line_style, label=curve_label)

        plt.xlabel("Threshold")
        plt.legend(loc='upper left')
        plt.ylim([0, 1])  # both metrics are plotted on a fixed 0..1 y range
        plt.show()
    71     
    72     
    def plot_roc_curve(fpr, tpr, label=None):
        """Plot a ROC curve together with a dashed diagonal, then show the figure.

        Args:
            fpr: values for the x axis (false-positive rates).
            tpr: values for the y axis (true-positive rates).
            label: legend text for the curve. When None, the legend entry is
                'roc', which is what this function always used before the
                parameter was honoured.
        """
        curve_label = 'roc' if label is None else label
        plt.plot(fpr, tpr, linewidth=2, label=curve_label)
        # Dashed black line from (0, 0) to (1, 1) as a visual reference.
        plt.plot([0, 1], [0, 1], 'k--', label='mid')
        plt.legend(loc='lower right')
        # Author's note on plt.axes([0, 1, 0, 1]) (not called here): the first
        # two values give the position of the origin, the last two the lengths
        # of the x and y axes.
        plt.xlabel('fpr')
        plt.ylabel('tpr')
        plt.show()
    81 
    82 
    # Precision and recall of the SGD classifier as functions of the threshold.
    plot_precision_recall_vs_threshold(precisions, recalls, thresholds)

    # ROC curve of the SGD classifier from its cross-validated scores.
    # Note: this rebinds `thresholds` to the ROC thresholds.
    fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
    plot_roc_curve(fpr, tpr)

    sgd_auc = roc_auc_score(y_train_5, y_scores)
    print(sgd_auc)

    # Same evaluation for a random forest; it is scored through predict_proba.
    forest_clf = RandomForestClassifier(random_state=42)
    y_probas_forest = cross_val_predict(
        forest_clf, x_train, y_train_5, cv=3, method='predict_proba',
    )
    # The second predict_proba column is used as the score fed to roc_curve.
    y_scores_forest = y_probas_forest[:, 1]
    fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)

    # Overlay both ROC curves in one figure.
    overlay = (
        ((fpr, tpr, 'b:'), 'SGD'),
        ((fpr_forest, tpr_forest), 'Random Forest'),
    )
    for plot_args, curve_name in overlay:
        plt.plot(*plot_args, label=curve_name)
    plt.legend(loc='lower right')
    plt.show()

    forest_auc = roc_auc_score(y_train_5, y_scores_forest)
    print(forest_auc)

              

    总结:精确率(precision)与召回率(recall)整体上呈此消彼长的权衡关系(并非严格意义上的反比);在使用相同数据集、相同交叉验证方式的情况下,从 ROC 曲线和 AUC 来看,随机森林要优于随机梯度下降分类器!

  • 相关阅读:
    HashMap与ArrayList的相互嵌套
    Mysql与Oracle 的对比
    什么是子查询
    创建存储过程
    cmd 快捷键
    navicat 快捷键
    Mysql的数据类型 6种
    Mysql与Oracle 的使用区别
    怎样修复ie浏览器
    Linux官方内置Bash中新发现一个非常严重安全漏洞
  • 原文地址:https://www.cnblogs.com/yszd/p/9620516.html
Copyright © 2011-2022 走看看