zoukankan      html  css  js  c++  java
  • 学习笔记26— roc曲线(python)

    一、概念:

    准确率(Accuracy), 精确率(Precision), 召回率(Recall)和F1-Measure

    机器学习(ML), 自然语言处理(NLP), 信息检索(IR)等领域, 评估(Evaluation)是一个必要的工作, 而其评价指标往往有如下几点: 准确率(Accuracy), 精确率(Precision), 召回率(Recall) 和 F1-Measure.(注:相对来说,IR 的 ground truth 很多时候是一个 Ordered List, 而不是一个 Bool 类型的 Unordered Collection,在都找到的情况下,排在第三名还是第四名损失并不是很大,而排在第一名和第一百名,虽然都是“找到了”,但是意义是不一样的,因此更多可能适用于 MAP 之类评估指标.)

    本文将简单介绍其中几个概念. 中文中这几个评价指标翻译各有不同, 所以一般情况下推荐使用英文.

    现在我先假定一个具体场景作为例子.

    假如某个班级有男生 80 人, 女生20人, 共计 100 人. 目标是找出所有女生. 现在某人挑选出 50 个人, 其中 20 人是女生, 另外还错误的把 30 个男生也当作女生挑选出来了. 作为评估者的你需要来评估(evaluation)下他的工作

    首先我们可以计算准确率(accuracy), 其定义是: 对于给定的测试数据集,分类器正确分类的样本数与总样本数之比. 也就是损失函数是0-1损失时测试数据集上的准确率[1].

    这样说听起来有点抽象,简单说就是,前面的场景中,实际情况是那个班级有男的和女的两类,某人(也就是定义中所说的分类器)他又把班级中的人分为男女两类. accuracy 需要得到的是此君分正确的人总人数的比例. 很容易,我们可以得到:他把其中70(20女+50男)人判定正确了, 而总人数是100人,所以它的 accuracy 就是70 %(70 / 100).

    由准确率,我们的确可以在一些场合,从某种意义上得到一个分类器是否有效,但它并不总是能有效的评价一个分类器的工作. 举个例子, google 抓取了 argcv 100个页面,而它索引中共有10,000,000个页面, 随机抽一个页面,分类下, 这是不是 argcv 的页面呢?如果以 accuracy 来判断我的工作,那我会把所有的页面都判断为"不是 argcv 的页面", 因为我这样效率非常高(return false, 一句话), 而 accuracy 已经到了99.999%(9,999,900/10,000,000), 完爆其它很多分类器辛辛苦苦算的值, 而我这个算法显然不是需求期待的, 那怎么解决呢?这就是 precision, recall 和 f1-measure 出场的时间了.

    再说 precision, recall 和 f1-measure 之前, 我们需要先需要定义 TP, FN, FP, TN 四种分类情况.

    按照前面例子, 我们需要从一个班级中的人中寻找所有女生, 如果把这个任务当成一个分类器的话, 那么女生就是我们需要的, 而男生不是, 所以我们称女生为"正类", 而男生为"负类".

     相关(Relevant), 正类无关(NonRelevant), 负类
    被检索到(Retrieved) true positives (TP 正类判定为正类, 例子中就是正确的判定"这位是女生") false positives (FP 负类判定为正类,"存伪", 例子中就是分明是男生却判断为女生, 当下伪娘横行, 这个错常有人犯)
    未被检索到(Not Retrieved) false negatives (FN 正类判定为负类,"去真", 例子中就是, 分明是女生, 这哥们却判断为男生--梁山伯同学犯的错就是这个) true negatives (TN 负类判定为负类, 也就是一个男生被判断为男生, 像我这样的纯爷们一准儿就会在此处)

    或者

    可以很容易看出, 所谓 TRUE/FALSE 表示从结果是否分对了, Positive/Negative 表示我们认为的是"是"还是"不是".

    通过这张表, 我们可以很容易得到这几个值:

    • TP=20
    • FP=30
    • FN=0
    • TN=50

    精确率(precision)的公式是P=TPTP+FPP = frac{TP}{TP+FP}P=TP+FPTP, 它计算的是所有"正确被检索的结果(TP)"占所有"实际被检索到的(TP+FP)"的比例.

    在例子中就是希望知道此君得到的所有人中, 正确的人(也就是女生)占有的比例. 所以其 precision 也就是40%(20女生/(20女生+30误判为女生的男生)).

    召回率(recall)的公式是R=TPTP+FNR = frac{TP}{TP+FN}R=TP+FNTP, 它计算的是所有"正确被检索的结果(TP)"占所有"应该检索到的结果(TP+FN)"的比例.

    在例子中就是希望知道此君得到的女生占本班中所有女生的比例, 所以其 recall 也就是100%(20女生/(20女生+ 0 误判为男生的女生))

    F1值就是精确值和召回率的调和均值, 也就是

    2F1=1P+1Rfrac{2}{F_1} = frac{1}{P} + frac{1}{R} F12=P1+R1

    调整下也就是

    F1=2PRP+R=2TP2TP+FP+FNF_1 = frac{2PR}{P+R} = frac{2TP}{2TP + FP + FN} F1=P+R2PR=2TP+FP+FN2TP

    例子中 F1-measure 也就是约为 57.143%(2∗0.4∗10.4+1frac{2 * 0.4 * 1}{0.4 + 1}0.4+120.41).

    需要说明的是, 有人[3]列了这样个公式, 对非负实数βetaβ

    Fβ=(β2+1)∗PRβ2∗P+RF_eta = (eta^2 + 1) * frac{PR}{eta^2*P+R}Fβ=(β2+1)β2P+RPR

    将 F-measure 一般化.

    F1-measure 认为精确率和召回率的权重是一样的, 但有些场景下, 我们可能认为精确率会更加重要, 调整参数 βetaβ , 使用 Fβ_etaβ-measure 可以帮助我们更好的 evaluate 结果.

     

    注意:召回率就是tpr(TP/(TP+FN)) (命中率),fpr(FP/(FP+TN)):特异性(specificity),TNR(TN/(FP+TN)):敏感性

    参考链接:http://mlwiki.org/index.php/ROC_Analysis#ROC_Analysis

                     https://blog.argcv.com/articles/1036.c

                    http://www.cnblogs.com/pinard/p/5993450.html

    1、简单(源码):

    print(__doc__)

    import numpy as np
    import matplotlib.pyplot as plt
    from itertools import cycle


    from sklearn import svm, datasets
    from sklearn.metrics import roc_curve, auc
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import label_binarize
    from sklearn.multiclass import OneVsRestClassifier
    from scipy import interp

    # Import some data to play with
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target

    # Binarize the output
    y = label_binarize(y, classes=[0, 1, 2])
    n_classes = y.shape[1]

    # Add noisy features to make the problem harder
    random_state = np.random.RandomState(0)
    n_samples, n_features = X.shape
    X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]

    # shuffle and split training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
                                                        random_state=0)

    # Learn to predict each class against the other
    classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
                                     random_state=random_state))
    y_score = classifier.fit(X_train, y_train).decision_function(X_test)

    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    #Plot of a ROC curve for a specific class

    plt.figure()
    lw = 2
    plt.plot(fpr[2], tpr[2], color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[2])
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    plt.show()
    2、复杂(源码):

    print(__doc__)
    
    import numpy as np
    from scipy import interp
    import matplotlib.pyplot as plt
    from itertools import cycle
    
    from sklearn import svm, datasets
    from sklearn.metrics import roc_curve, auc
    from sklearn.model_selection import StratifiedKFold
    
    # #############################################################################
    # Data IO and generation
    
    # Import some data to play with
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    X, y = X[y != 2], y[y != 2]
    n_samples, n_features = X.shape
    
    # Add noisy features
    random_state = np.random.RandomState(0)
    X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
    
    # #############################################################################
    # Classification and ROC analysis
    
    # Run classifier with cross-validation and plot ROC curves
    cv = StratifiedKFold(n_splits=6)
    classifier = svm.SVC(kernel='linear', probability=True,
                         random_state=random_state)
    
    tprs = []
    aucs = []
    mean_fpr = np.linspace(0, 1, 100)
    
    i = 0
    for train, test in cv.split(X, y):
        probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
        # Compute ROC curve and area the curve
        fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
        tprs.append(interp(mean_fpr, fpr, tpr))
        tprs[-1][0] = 0.0
        roc_auc = auc(fpr, tpr)
        aucs.append(roc_auc)
        plt.plot(fpr, tpr, lw=1, alpha=0.3,
                 label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc))
    
        i += 1
    plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
             label='Chance', alpha=.8)
    
    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    plt.plot(mean_fpr, mean_tpr, color='b',
             label=r'Mean ROC (AUC = %0.2f $pm$ %0.2f)' % (mean_auc, std_auc),
             lw=2, alpha=.8)
    
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                     label=r'$pm$ 1 std. dev.')
    
    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    plt.show()

    Best threshold:

      • we know the slope of the accuracy line: it's 1
      • the best classifier for this slope is the 6th one
      • roc-curve-ex1.png
      • threshold value θ
      • so we take the score obtained on the 6th record
      • and use it as the threshold value θ
    • i.e. predict positive if θ0.54
    • if we check, we see that indeed we have accuracy = 0.7
    
    
  • 相关阅读:
    HBASE列族不能太多的真相 (一个table有几个列族就有几个 Store)
    Linux虚拟机添加新硬盘的全程图解
    Servlet 单例多线程
    MapReduce类型与格式(输入与输出)
    hbase集群的启动,注意几个问题
    spring 的IoC的个人理解
    深入Java核心 Java中多态的实现机制(1)
    spring mvc 请求转发和重定向(转)
    XML中<beans>中属性概述
    hadoop+javaWeb的开发中遇到包冲突问题(java.lang.VerifyError)
  • 原文地址:https://www.cnblogs.com/hechangchun/p/9840492.html
Copyright © 2011-2022 走看看