zoukankan      html  css  js  c++  java
  • Multilabel(多标签分类)metrics:hamming loss,F score

    最近在做一个multilabel classification(多标签分类)的项目,需要一些特定的metrics去评判一个multilabel classifier的优劣。这里对用到的三个metrics做一个总结。

    首先明确一下多标签(multilabel)分类和多类别(multiclass)分类的不同:multiclass仅仅表示输出的类别大于2个,这样可以和一般的二分类(binary)区别开,但每一个输入x仅仅对应一个输出标签。而multilabel里的每一个输入可以对应多个输出标签。举个例子,对动物图片进行分类,一个输入(某一动物)只可能有一个输出(猫or狗),因为一个动物不可能同时是猫还有狗,这就是multiclass;但对新闻进行分类,一个输入(某一新闻)可以对应多个输出(海外,体育),这就是multilabel。

    这次用到的数据集是MS-COCO 2014年的数据集,它的testset里有40504张图片,因为MS-COCO主要是用来做目标检测(detection)或者物体分割的(segmentation),每张图片里包含了多个物体。我们就用这些物体为图片的对应标签来做multilabel分类。一共有80个可能出现在图片里的物体,所以每张图片对应一个(80x1)的标签向量,被分到类的物体对应的标量为1,没有分到类的为0(只要图片中出现了某物体,此图片就被分到对应的物体类里,然后一张图片会有多个类别因为大多图片都包含了几个物体)。

    N表示testset的大小,L表示所有备选label合集的大小,在MS-COCO2014的testset里,(M=40504,L=80)

    Hamming Loss

    (h_{loss})可能是最直观也是最容易理解的一个loss,它直接统计了被误分类label的个数(不属于这个样本的标签被预测,或者属于这个样本的标签没有被预测)。(h_{loss}=0)表示所有的每一个data的所有label都被分对了。

    (h_{loss}= frac{1}{p}∑_{i=1}^ph(x_{i})ΔY_{i}), where (p=N imes L)

    (F_{1}) score

    介绍F score之前首先要理清准确率(Accuracy),精确率(Precision)和召回率(Recall)之间的关系。

    Accuracy的定义是:分类器正确分类的次数与总分类数之比。Hamming Loss可以说是accuracy的一种呈现。但其实如果只追求hamming loss/accuracy的话,就会出现以下问题:已知MS-COCO里的大多数图片只包含几个(<=4)物体,这样只要把所有图片的标签都设为0,也能达到大约4/80(0.08)的hamming loss。

    所以我们需要precision(查准率)和recall(查全率)作为别的metrics去评判一个classifier的给力程度。下面解释了如何用precision和recall,还有这两个metrics的合成F score来评判一个multilabel classifier。

    precision和recall一开始是从信息检索衍生出来的。precision(查准率)计算的是所有"正确被检索的item(TP)"占所有"实际被检索到的(TP+FP)"的比例,recall(查全率)计算的是所有"正确被检索的item(TP)"占所有"应该检索到的item(TP+FN)"的比例。后来也可以延伸去二分类问题,如下表所示。

    gtpred 1 0
    1 TP FP
    0 FN TN

    precision:(P = frac{TP}{TP+FP})

    recall:(R = frac{TP}{TP+FN})

    (F_{1})值就是precision和 recall的调和均值(harmonic mean):(frac{2}{F_1} = frac{1}{P} + frac{1}{R}),也能写成(F_1 = frac{2PR}{P+R} = frac{2TP}{2TP + FP + FN})

    Scikit-learn的官方文档中把precision,recall的关系解释得十分清楚(https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html#sphx-glr-auto-examples-model-selection-plot-precision-recall-py

    实际上(F_{1})(F_{eta})的一个特殊形式,在一些应用中,对查准率和查全率的重视程度不同。比如在商品推荐中,要尽可能少的打扰用户,所以查准率更重要;在疾病筛查中,要尽可能漏掉可能疾病,所以查全率更重要。

    (F_{eta}= frac{(eta^2 + 1 )PR}{eta^2P+R}),从左边式子可以看出(eta>1)时查全率更重要,(eta<1)时查准率更重要,(eta=1)时查全和查准同样重要。

    (F_{1}) score in Multilabel Classification:

    在针对multilabel分类计算F score的时候,通常有macro和micro两种average的方法。Python的scikit-learn库在计算f1 score也提供了micro和macro两种选择,具体在multilabel的情况下,怎么计算(F_{1}) score,在网上查阅了很多博客和资料都没有给出一个明确的用列子解释的步骤,这边我自己通过整合资料代码验证出了macro和micro两种(F_1)score的计算方法。

    请看下面的简单例子:

    (,N=3, L=3) ,一共有三个data,每个data有三个预备分类:

    import numpy as np
    y_gt = np.array ([[1,0,1],[0,1,1],[0,1,0]])
    y_pred = np.array ([[0,0,1],[1,1,1],[1,1,1]])
    from sklearn.metrics import f1_score
    f1_score(y_gt, y_pred, average="macro") #0.6
    f1_score(y_gt, y_pred, average="micro") #0.66666666666666666
    

    对于每一个class,我们都需要先算一个2x2的confusion matrix,里面分别标明了对于这一个class的TP,FP,FN和TN。注意,这个confusion matrix不是multiclass里的那个,而是类似于binary classifier里,只有0和1分类的confusion matrix。

    macro:

    对于macro,我们通过每一个class的confusion matrix算出它的precision和recall,并计算出对与那个class的F1 score,最后通过平均所有class的(F_1) score得到(F1_{macro})

    Class 0:

    gtpred 1 0
    1 0 1
    0 2 0

    (P_0 = frac{0}{0+2}=0, R_0 = frac{0}{0+1}=0)

    (F1_0=frac{2 imes0 imes0}{0+0}=0)

    Class 1:

    gtpred 1 0
    1 2 0
    0 0 1

    (P_1 = frac{2}{2+0}=1, R_1 = frac{2}{2+0}=1)

    (F1_1=frac{2 imes1 imes1}{1+1}=1)

    Class 2:**

    gtpred 1 0
    1 2 0
    0 1 0

    (P_2 = frac{2}{2+1}=frac{2}{3}, R_2 = frac{2}{2+0}=1)

    (F1_2=frac{2 imes{frac{2}{3}} imes1}{frac{2}{3}+1}=frac{4}{5})

    (F1_{macro}=frac{1}{N}(F1_0+ F1_1+…+F1_N)=frac{1}{3}(F1_0+F1_1+F1_2)=frac{1}{3}(0+1+frac{4}{5})=0.6)

    micro:

    对于micro,我们把所有class的binary confusion matrix整合成一个大的2x2confusion matrix,然后并对于整合成的confusion matrix算出一个precision和recall值((P_{comb}) and (R_{comb})),最后通过公式得到(F1_{micro}) score。

    Combined all classes:

    gtpred 1 0
    1 4 1
    0 3 1

    (P_{comb}=frac{4}{4+3}=frac{4}{7}, R_{comb}=frac{4}{4+1}=frac{4}{5})

    (F1_{micro}=frac{2P_{comb}R_{comb}}{P_{comb}+R_{comb}}=frac{2 imesfrac{4}{7} imesfrac{4}{5}}{frac{4}{7}+frac{4}{5}}=0.66666666...)

    Mean Average Precision:

  • 相关阅读:
    使用shc加密bash脚本程序
    shell加密工具shc的安装和使用
    cgi程序报 Premature end of script headers:
    gearmand安装过程
    解决Gearman 报sqlite3错误
    gearman安装实录
    PHP APC安装与使用
    在Centos上面用yum不能安装redis的朋友看过来
    CentOS 5
    CentOS安装配置MongoDB
  • 原文地址:https://www.cnblogs.com/fledlingbird/p/10675922.html
Copyright © 2011-2022 走看看