zoukankan      html  css  js  c++  java
  • 理解accuracy/precision_score、micro/macro

    • 1、accuracy即我们通常理解的准确率,计算的时候是指在预测值pred与目标值target之间重叠的部分的大小除以pred的大小(或target的大小,因为sklearn要求pred与target大小必须一致)。

    比如

    target=[2,2,2,3,2]
    pred=[2,2,1,3,4]

    此时重叠的部分为pred[0]、pred[1]、pred[3],accuracy=3/5=0.6

    而precision_score意思类似,但不完全一样,它是从pred的角度考虑,意思是我虽然预测了这么多个,但有几个预测对了?占比多少?这就是precision。
    举例:

    target=np.array([1,1,1,1,1])
    pred=np.array([2,2,1,1,1])
    p=precision_score(target,pred)
    print(p)#1.0

    的结果为1.0(看不懂?马上解释);
    但如果是accuracy_score(target,pred),就会输出0.60(预测了5个,对了3个)  此外,

    p=precision_score(target,pred)

     其实等价于

    p=precision_score(target,pred,average='binary',pos_label=1)

    average='binary'表示pred当中(最多)只有两种标签(此处是1,2),pos_label=1表示最后输出的结果是针对类别值"1"的计算结果。

    如果改成pos_label=2

    p=precision_score(target,pred,average='binary',pos_label=2)

    则表示计算结果是针对类别为"2"的统计结果,结果为0.0(因为pred中有两个2,但都预测错了,所以为0)。

    pos_label这个参数只有average='binary'时管用,若pred中出现3种及以上类别的标签,则pos_label参数即使设置了也会被忽略。

    • 2、下面说说average参数的作用,当pred当中有3种或以上类别时,average的值只能取[None, 'micro', 'macro', 'weighted']当中的一种,

    其中,有趣的是,当average='micro'时,precision_score(target,pred,average='micro')等价于accuracy的算法,就是看预测对的标签总个数,再除以pred的大小。

    而当average='macro'时,会计算每个种类标签的precision_score,再取平均(不考虑各个类别的样本分布差异)。例如:

    target=np.array([1,1,2,3,3])
    pred=np.array([2,2,1,3,4])
    precision_score(target,pred,average='macro')

    的结果为0.25,怎么算出来的0.25?

    这样算:首先确定,pred中一共是4个类别(1,2,3,4),预测对的只有pred[3],

    类别1的precision=0/1(pred中只有1个标签1,预测对了0个)结果为0
    类别2的precision=0/2(pred中有2个标签2,预测对了0个),结果为0
    类别3的precision=1/1(pred中有1个标签3,预测对了1个),结果为1
    类别4的precision=0/1(pred中有1个标签4,预测对了0个),结果为0

    再取平均(不考虑类别的分布差异):(0+0+1+0)/4=0.25。

    Macro另外一个需要注意的地方在于,它在做平均的时候,它认为的类别数不是看target有多少个类别,也不是看pred有多少个类别,而是两者取并集,作为总的类别数。

    例如,

    target=np.array([1,1,3,3,4])
    pred = np.array([1,1,2,0,0,])
    print(precision_score(target,pred,average='macro'))#输出0.2

    输出0.2,你敢信?并且报了一个警告:

    D:Anaconda3libsite-packagessklearnmetrics\_classification.py:1221: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
      _warn_prf(average, modifier, msg_start, len(result))

    说那些没有被预测的类别的precision_score默认为0,

    于是,就是类别1的预测precision_score=1,但一共有0,1,2,3,4,共5类别,而其余类别的precison均为0,所以1/5=0.2。

    而当average='weighted'时,前面的步骤跟macro相同,分别计算pred中各个类别的precision_score,但汇总时则会考虑各个类别的样本的分布差异。例如,

    target=np.array([1,1,2,2,3,3])
    pred=np.array([2,2,1,3,4,3])
    precision_score(target,pred,average='macro')

    输出0.125

    计算过程:pred中有4个类别,类别1、2、4预测正确的个数为0,precision_score均为0,而类别3,预测了2个(pred[3]、pred[5]),仅对了1个(pred[5]),precision_score=0.5,再取平均(0+0+0.5+0)/4=0.125,而

    precision_score(target,pred,average='weighted')

     则输出0.1666

    因为average='weighted'时,会考虑各个类别(在target中,而非pred中)的分布比例,比如,类别1在target中出现占比为2/6,同理类别3的占比为2/6=1/3,
    因此最后结果等于0*1/3+0*1/3+0.5*1/3+0*1/3=1/6=0.1666

    • 3、再讲讲召回率recall_score的计算方法。

    举例:

    target=np.array([1,2,2,2,3,3])
    pred=np.array([2,2,1,3,4,3])
    recall_score(target,pred,average='macro')

    输出0.20833333333333331,其实就是5/24。怎么来的?

    首先依据定义,pred的召回率是指pred中出现的各类标签,预测正确的次数占target中该类标签数的比例。

    类别1的recall=0/1=0
    类别2的recall=1/3(1个,即pred[1]预测正确,target中类别2出现了3次)
    类别3的recall=1/2(1个,即pred[5]预测正确,target中类别3出现2次)
    类别4的recall=0,类别4在target中没有出现,因此预测错误。

    最后求平均:(1/3+1/2)/4=5/24=0.20833333333333334

    同理,average='weighted'时,

    recall_score(np.array([1,2,2,2,3,3]),np.array([2,2,1,3,4,3]),average='weighted')

    输出0.333333=1/3

    即分别用标签1、2、3、4的recall_score再乘上它们再target中的占比,0*1/6+1/3*3/6+1/2*2/6+0*0/6=1/3。

    • 4、另外,需要注意的是sklearn.metrics模块中的precision_score,recall_score中对比的target和pred参数都是一维数组(若不是一维数组,则会认为某一行代表一个样本有多个标签),因此在模型的批训练时,获得的预测结果,需要先转换成一维数组再计算precision,recall,f1等指标。

    参考:https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score

  • 相关阅读:
    【根据条件添加属性】vue页面标签根据条件添加属性
    serialVersionUID
    onsubmit="return navTabSearch(this);"
    MyEclipse改变项目的编码方式
    Tomcat端口被占用
    可拖动图层
    顶部可以折叠的菜单工具栏
    转---- javascript prototype介绍的文章
    网页右侧弹出有缓冲效果的工具栏
    根据时间改变背景
  • 原文地址:https://www.cnblogs.com/aaronhoo/p/13865317.html
Copyright © 2011-2022 走看看