  • GluonCV RPN positive/negative samples

    https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/rpn/rpn_target.py

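    import numpy as np
    import mxnet as mx

        # forward() of the RPNTargetSampler block; the self._* attributes are set in its __init__.
        def forward(self, ious):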
            """RPNTargetSampler is only used in data transform with no batch dimension.
            Parameters
            ----------
            ious: (N, M) i.e. (num_anchors, num_gt).
            Returns
            -------
            samples: (num_anchors,) value 1: pos, -1: neg, 0: ignore.
            matches: (num_anchors,) value [0, M).
            """
            matches = mx.nd.argmax(ious, axis=1)
    
            # samples init with 0 (ignore)
            ious_max_per_anchor = mx.nd.max(ious, axis=1)
            samples = mx.nd.zeros_like(ious_max_per_anchor)
    
            # best IoU reached by each gt over all anchors, shape (1, num_gt)
            ious_max_per_gt = mx.nd.max(ious, axis=0, keepdims=True)
            # ious (num_anchor, num_gt) >= per-gt max (1, num_gt), up to eps -> mark row as positive
            mask = mx.nd.broadcast_greater(ious + self._eps, ious_max_per_gt)
            # reduce column (num_anchor, num_gt) -> (num_anchor)
            mask = mx.nd.sum(mask, axis=1)
            # a row may be selected by two columns (gts), but it still matches only its most overlapping gt
            samples = mx.nd.where(mask, mx.nd.ones_like(samples), samples)
    
            # set positive overlap to 1
            samples = mx.nd.where(ious_max_per_anchor >= self._pos_iou_thresh,
                                  mx.nd.ones_like(samples), samples)
            # set negative overlap to -1
            tmp = (ious_max_per_anchor < self._neg_iou_thresh) * (ious_max_per_anchor >= 0)
            samples = mx.nd.where(tmp, mx.nd.ones_like(samples) * -1, samples)
    
            # subsample fg labels
            samples = samples.asnumpy()
            num_pos = int((samples > 0).sum())
            if num_pos > self._max_pos:
                disable_indices = np.random.choice(
                    np.where(samples > 0)[0], size=(num_pos - self._max_pos), replace=False)
                samples[disable_indices] = 0  # use 0 to ignore
    
            # subsample bg labels
            num_neg = int((samples < 0).sum())
            # if pos_sample is less than quota, we can have negative samples filling the gap
            max_neg = self._num_sample - min(num_pos, self._max_pos)
            if num_neg > max_neg:
                disable_indices = np.random.choice(
                    np.where(samples < 0)[0], size=(num_neg - max_neg), replace=False)
                samples[disable_indices] = 0
    
            # convert to ndarray
            samples = mx.nd.array(samples, ctx=matches.context)
            return samples, matches
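The labeling and subsampling steps above can be reproduced with plain NumPy, which makes it easy to try them on a toy IoU matrix without MXNet. This is a minimal sketch, not GluonCV's API: the function names `assign_rpn_labels` and `subsample` are hypothetical, and the defaults (0.7 / 0.3 IoU thresholds, 256 samples with at most 128 positives, tie tolerance 1e-6) are assumed values in the spirit of the usual Faster R-CNN settings rather than values read from the code above.

```python
import numpy as np


def assign_rpn_labels(ious, pos_iou_thresh=0.7, neg_iou_thresh=0.3, eps=1e-6):
    """Label anchors from a (num_anchors, num_gt) IoU matrix.

    Hypothetical helper mirroring the MXNet forward() above.
    Returns (samples, matches): samples is 1 (pos), -1 (neg) or 0 (ignore);
    matches is the index of the best-overlapping gt for each anchor.
    """
    matches = ious.argmax(axis=1)
    ious_max_per_anchor = ious.max(axis=1)
    samples = np.zeros(ious.shape[0], dtype=np.float32)

    # an anchor that ties (within eps) for the best IoU of any gt becomes positive
    ious_max_per_gt = ious.max(axis=0, keepdims=True)  # (1, num_gt)
    is_best_for_some_gt = ((ious + eps) > ious_max_per_gt).any(axis=1)
    samples[is_best_for_some_gt] = 1

    # threshold rules, applied in the same order as forward(),
    # so the negative rule can overwrite a forced positive
    samples[ious_max_per_anchor >= pos_iou_thresh] = 1
    samples[(ious_max_per_anchor < neg_iou_thresh) & (ious_max_per_anchor >= 0)] = -1
    return samples, matches


def subsample(samples, num_sample=256, max_pos=128, rng=None):
    """Randomly ignore (set to 0) surplus positives, then surplus negatives."""
    rng = np.random.default_rng() if rng is None else rng
    samples = samples.copy()  # leave the caller's labels untouched

    pos_idx = np.flatnonzero(samples > 0)
    num_pos = len(pos_idx)
    if num_pos > max_pos:
        samples[rng.choice(pos_idx, size=num_pos - max_pos, replace=False)] = 0

    neg_idx = np.flatnonzero(samples < 0)
    # negatives fill whatever part of the quota the positives did not use
    max_neg = num_sample - min(num_pos, max_pos)
    if len(neg_idx) > max_neg:
        samples[rng.choice(neg_idx, size=len(neg_idx) - max_neg, replace=False)] = 0
    return samples


if __name__ == "__main__":
    toy_ious = np.array([
        [0.80, 0.10],  # above the positive threshold
        [0.50, 0.20],  # between the thresholds -> ignored
        [0.10, 0.05],  # below the negative threshold
        [0.20, 0.60],  # best anchor for gt 1, so forced positive
        [-1.0, -1.0],  # invalid anchor (IoU -1) -> stays ignored
    ], dtype=np.float32)
    labels, best_gt = assign_rpn_labels(toy_ious)
    print(labels, best_gt)
```

Because the threshold rules run in the same order as in `forward`, an anchor that is the best match for some gt but whose IoU falls below the negative threshold ends up labeled -1; applying the positive rules last would keep it positive instead.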

  • Original post: https://www.cnblogs.com/TreeDream/p/10192410.html