zoukankan      html  css  js  c++  java
  • 我是如何使计算提速>150倍的

    我是如何使计算提速>150倍的

    我的原始文档:https://www.yuque.com/lart/blog/lwgt38

    书接上文《我是如何使计算时间提速25.6倍》.

    上篇文章提到, F-measure使用累计直方图可以进一步加速计算, 但是E-measure却没有改出来. 在写完上篇文章的那个晚上, 重新整理思路后, 我似乎想到了如何去使用累计直方图来再次提速.

    速度的制约

    虽然使用"解耦"的思路可以高效优化每一个阈值下指标的计算过程, 但是整体的 for 循环确实仍然会占用较大的时间. 又考虑到各个阈值下的计算实际上并无太大关联, 如果可以实现同时计算, 那必然可以进一步提升速度. 这里我们又要把目光放回到在计算F-measure时大放光彩的累计直方图的策略上.

    在前面的解耦之后, 实际上获得的关键变量是 fg_fg_numelfg_bg_numel .

    fg_fg_numel = np.count_nonzero(binarized_pred & gt)
    fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
    

    从这两个变量本身入手, 如果使用累计直方图的话, 实际上可以同时获得 >=不同阈值 下的前景像素(值为1)的数量, 计算的本质和 np.count_nonzero 是一样的东西. 所以我们可以进行直观的替换:

    """
    函数内部变量命名规则:
        pred属性(前景fg、背景bg)_gt属性(前景fg、背景bg)_变量含义
        如果仅考虑pred或者gt,则另一个对应的属性位置使用`_`替换
    """
    fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
    fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
    fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
    fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
    

    这样我们就获得了不同阈值下的对应的一系列 fg_fg_numelfg_bg_numel 了. 这里需要注意的是, 使用的划分区间 bins 的设置. 由于默认的 histogram 划分的区间会包含最后一个端点, 所以比较合理的划分是 bins = np.linspace(0, 256, 257) , 这样最后一个区间是 [255, 256] , 就可以包含到最大的值, 又不会和 254 重复计数.

    为了便于计算, 这里将后面会用到的 pred 前景统计 fg___numel_w_thrs 和背景统计 bg____numel_w_thrs 直接写出来, 便于使用:

    fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
    bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
    

    后面的步骤和之前的基本一致, numpy的广播机制使得不需要改动太多. 由于这部分代码实际上再多处位置会被使用, 所以提取成一个单独的方法.

    def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel):
        bg_fg_numel = self.gt_fg_numel - fg_fg_numel
        bg_bg_numel = pred_bg_numel - bg_fg_numel
    
        parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel]
    
        mean_pred_value = pred_fg_numel / self.gt_size
        mean_gt_value = self.gt_fg_numel / self.gt_size
    
        demeaned_pred_fg_value = 1 - mean_pred_value
        demeaned_pred_bg_value = 0 - mean_pred_value
        demeaned_gt_fg_value = 1 - mean_gt_value
        demeaned_gt_bg_value = 0 - mean_gt_value
    
        combinations = [
            (demeaned_pred_fg_value, demeaned_gt_fg_value),
            (demeaned_pred_fg_value, demeaned_gt_bg_value),
            (demeaned_pred_bg_value, demeaned_gt_fg_value),
            (demeaned_pred_bg_value, demeaned_gt_bg_value)
        ]
        return parts_numel, combinations
    

    后面计算 enhanced_matrix_sum 的部分也就顺理成章比较自然的可以写出来:

    parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
        fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
        pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
    )
    
    # 这里虽然可以使用列表来收集各个results_part,但是列表之后还需要再转为numpy数组来求和,倒不如直接一次性申请好空间后面直接装入即可
    results_parts = np.empty(shape=(4, 256), dtype=np.float64)
    for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
        align_matrix_value = 2 * (combination[0] * combination[1]) / 
                                (combination[0] ** 2 + combination[1] ** 2 + _EPS)
        enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
        results_parts[i] = enhanced_matrix_value * part_numel
    enhanced_matrix_sum = results_parts.sum(axis=0)
    

    整体梳理

    主要逻辑已经搞定, 接下来就是将这些代码与原始的代码融合起来, 也就是整合原始代码的 cal_em_with_thresholdcal_enhanced_matrix 两个方法.

    原始代码中 https://github.com/lartpang/CodeForArticle/blob/7a922c720702c727d7a28fd17f3db66e0b9ba135/sod_metrics/metrics/metric_best.py#L46-L58

    def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
        binarized_pred = pred >= threshold
    
        if self.gt_fg_numel == 0:
            binarized_pred_bg_numel = np.count_nonzero(~binarized_pred)
            enhanced_matrix_sum = binarized_pred_bg_numel
        elif self.gt_fg_numel == self.gt_size:
            binarized_pred_fg_numel = np.count_nonzero(binarized_pred)
            enhanced_matrix_sum = binarized_pred_fg_numel
        else:
            enhanced_matrix_sum = self.cal_enhanced_matrix(binarized_pred, gt)
        em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
        return em
    

    结合前面代码中计算出的各个阈值下的前背景元素的统计值, 上面这里的代码实际上可以通过使用现有运算结果进行化简, 即 if 的前两个分支. 另外阈值划分也不需要显式处理, 因为已经在累计直方图中搞定了. 所以这里的代码对于动态阈值计算的情况下, 是可以被合并到 cal_enhanced_matrix 的计算过程中的. 直接得到最终的整合后的方法:

    def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
        """
        函数内部变量命名规则:
            pred属性(前景fg、背景bg)_gt属性(前景fg、背景bg)_变量含义
            如果仅考虑pred或者gt,则另一个对应的属性位置使用`_`替换
        """
        pred = (pred * 255).astype(np.uint8)
        bins = np.linspace(0, 256, 257)
        fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
        fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
        fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
        fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
    
        fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
        bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
    
        if self.gt_fg_numel == 0:
            enhanced_matrix_sum = bg___numel_w_thrs
        elif self.gt_fg_numel == self.gt_size:
            enhanced_matrix_sum = fg___numel_w_thrs
        else:
            parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
                fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
                pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
            )
    
            results_parts = np.empty(shape=(4, 256), dtype=np.float64)
            for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
                align_matrix_value = 2 * (combination[0] * combination[1]) / 
                                        (combination[0] ** 2 + combination[1] ** 2 + _EPS)
                enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
                results_parts[i] = enhanced_matrix_value * part_numel
            enhanced_matrix_sum = results_parts.sum(axis=0)
    
        em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
        return em
    

    还是为了重用, cal_em_with_threshold (该方法需要保留, 因为还有另一种E-measure的计算情况需要用到该方法)可以被重构:

    def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
        """
        函数内部变量命名规则:
            pred属性(前景fg、背景bg)_gt属性(前景fg、背景bg)_变量含义
            如果仅考虑pred或者gt,则另一个对应的属性位置使用`_`替换
        """
        binarized_pred = pred >= threshold
        fg_fg_numel = np.count_nonzero(binarized_pred & gt)
        fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
    
        fg___numel = fg_fg_numel + fg_bg_numel
        bg___numel = self.gt_size - fg___numel
    
        if self.gt_fg_numel == 0:
            enhanced_matrix_sum = bg___numel
        elif self.gt_fg_numel == self.gt_size:
            enhanced_matrix_sum = fg___numel
        else:
            parts_numel, combinations = self.generate_parts_numel_combinations(
                fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel,
                pred_fg_numel=fg___numel, pred_bg_numel=bg___numel,
            )
    
            results_parts = []
            for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
                align_matrix_value = 2 * (combination[0] * combination[1]) / 
                                        (combination[0] ** 2 + combination[1] ** 2 + _EPS)
                enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
                results_parts.append(enhanced_matrix_value * part_numel)
            enhanced_matrix_sum = sum(results_parts)
    
        em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
        return em
    

    效率对比

    使用本地的845张灰度预测图和二值mask真值数据进行测试比较, 重新跑了一遍, 总体时间对比如下:

    方法 总体耗时(s) 速度提升(倍)
    'base' 539.2173762321472s x1
    'best' 19.94518733024597s x27.0 (539.22/19.95)
    'cumsumhistogram' 3.2935903072357178s x163.8 (539.22/3.29)

    还是那句话, 虽然具体时间可能还受硬件限制, 但是相对快慢还是比较明显的.

    测试代码可见我的 github : https://github.com/lartpang/CodeForArticle/tree/main/sod_metrics

  • 相关阅读:
    selenium webdriver使用过程中出现Element is not currently visible and so may not be interacted with的处理方法
    转:使用C#的HttpWebRequest模拟登陆网站
    第一个WCF服务
    WinForm DataGridview.AutoSizeColumnsMode属性
    Winform 程序中DataGridView 控件添加行号
    DevExpress GridControl行颜色标识
    基于T4模板的文档生成
    NHibernate 支持的数据库及配置参数
    MahApps.Metro
    Log4Net 配置文件样例
  • 原文地址:https://www.cnblogs.com/lart/p/14063080.html
Copyright © 2011-2022 走看看