zoukankan      html  css  js  c++  java
  • mIoU混淆矩阵生成函数代码详解

    代码参考博客原文: https://blog.csdn.net/jiongnima/article/details/84750819

    在原文和原文的引用里,找到了关于mIoU详尽的解释。这里重点解析 fast_hist(a, b, n) 这个函数的代码。

    生成混淆矩阵的代码: 

    1 #设标签宽W,长H
    2 def fast_hist(a, b, n):#a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的标签,形状(H×W,);n是类别数目,实数(在这里为19)
    3     '''
    4     核心代码
    5     '''
    6     k = (a >= 0) & (a < n)#k是一个一维bool数组,形状(H×W,);目的是找出标签中需要计算的类别(去掉了背景)
    7     return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)#np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)

    在调用了 k = (a >= 0) & (a < n) 以后,得到了bool数组,那它长什么样子呢?举个栗子说明:

    构造一个4×4的数组a,把背景值设置为255,除背景外类别共3个,分别为1, 2, 3

    mushroomer@mushroomerMate:~$ python3
    Python 3.7.1rc2 (default, Jun 14 2019, 23:23:01) 
    [GCC 5.4.0 20160609] on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import numpy as np
    >>> a = np.array([[255, 0, 0, 255], [255, 255, 2, 2], [1, 1, 1, 255], [255, 255, 255, 255]])>>> a
    array([[255,   0,   0, 255],
           [255, 255,   2,   2],
           [  1,   1,   1, 255],
           [255, 255, 255, 255]])
    >>> n = 3
    >>> k = (a >= 0) & (a < n)
    >>> k
    array([[False,  True,  True, False],
           [False, False,  True,  True],
           [ True,  True,  True, False],
           [False, False, False, False]])
    >>> a[k]
    array([0, 0, 2, 2, 1, 1, 1])

    可以看出,k是个和a尺寸相同的bool数组,有效类别都标记为True,背景全部标记为False

    a[k] 会把 k 标记的 True 对应在 a 中的值都提取出来。

    再以 n = 3 为例,混淆矩阵如下:

    混淆矩阵映射关系:

    $index=n*class(a)+class(b)$

    之后是np.bincount, 这个函数统计下标在目标列表中出现的次数。例如:

    Python 3.7.1 (default, Dec 10 2018, 22:54:23) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import numpy as np
    >>> np.bincount([0, 0, 0, 2, 1, 1, 3])
    array([3, 2, 1, 1], dtype=int64)
    >>> np.bincount([0, 0, 0, 2, 1, 1, 3], minlength=7)
    array([3, 2, 1, 1, 0, 0, 0], dtype=int64)
    >>> np.bincount([0, 0, 0, 2, 1, 1, 9], minlength=7)
    array([3, 2, 1, 0, 0, 0, 0, 0, 0, 1], dtype=int64)

    列表中最大值为3,统计 [0, 1, 2, 3] 对应每个元素在输入列表中出现的次数,得到 [3, 2, 1, 1], 含义是:0出现3次;1出现2次;2出现1次;3出现1次。

    如果指定 minlength, 则认为列表中最大值为 max_value = max(max([0, 0, 0, 2, 1, 1, 3]), minlength),然后去统计 list(range(max_value)) 对应每个元素在输入列表中出现的次数。

    在 fast_hist 函数中指定 minlength = n ** 2, 目的是使输出长度为 n ** 2, 输出形状就正好可以转换为 n * n 矩阵。当然根据 np.bincount 函数的特性,类别值如果超过 minlength,输出长度就不是 n ** 2 了,因此我举的栗子里背景值为 255 显然是不合适的,^_^,意识到了吗?

    然后统计出来混淆矩阵每个 index 对应的 (class a 重叠 class b) 出现的次数,就得到了结果。这里的映射关系重点是要理解每个 index 都对应唯一一个 class a 重叠 class b,例如 n = 3, class a = 1, class b = 2,那么对应的 index = 3*1 + 2 = 5,对应填到混淆矩阵里。假如 class a = 2, class b = 1, 那 index = 3*2 + 1 = 7,index 就变成了7,这个 index 是一一对应的。

  • 相关阅读:
    easyui-layout完整web界面布局
    combobox中动态载入tree数据
    easyui---layout 有无横的间隔 的区别 split:true
    单选按钮radio获取选中的值
    Datagrid清空数据
    Lancher3默认桌面显示
    菜单背景全透效果
    android通过耳机控制音乐播放器
    android 音乐暂停
    Android 系统默认音量和最大音量
  • 原文地址:https://www.cnblogs.com/adjwang/p/12194607.html
Copyright © 2011-2022 走看看