zoukankan      html  css  js  c++  java
  • 计算两张图的余弦相似度

    # 结果余弦相似度对比
    import numpy as np
    import pdb
     
    def count_difference(groundtruth, inputs):
        """Compute difference statistics for arrays shared by two mappings.

        Every key of ``groundtruth`` that is also present in ``inputs`` is
        compared; keys that appear in only one mapping are skipped. Both arrays
        are flattened to column vectors with ``reshape(-1, 1)`` before the
        metrics are evaluated, so only the element count has to match, not the
        original shape.

        Args:
            groundtruth: Mapping from name to the reference array.
            inputs: Mapping from name to the array being checked.

        Returns:
            A dict ``{name: {metric_name: value}}``. The metrics, in order, are
            ``cosine_similarity``, ``maximum_absolute_error``,
            ``accumulated_relative_error``, ``relative_euclidean_distance``,
            ``kullback_leibler_divergence`` and ``standard_deviation`` (a list of
            ``(mean, std)`` tuples: first for the reference, then for the
            checked array).

        Raises:
            AssertionError: If the two arrays stored under the same key hold a
                different number of elements.
        """

        def _accumulated_relative_error(X1, X2):
            # Sum of |(X1 - X2) / X1|. Zeros in X1 yield nan/inf, which
            # nan_to_num replaces afterwards, so the floating-point warnings
            # raised by the division itself carry no information.
            with np.errstate(divide='ignore', invalid='ignore'):
                ratio = (X1 - X2) / X1
            return np.sum(np.abs(np.nan_to_num(ratio)))

        def _kullback_leibler_divergence(X1, X2):
            # Sum of X1 * log(X1 / X2). Non-finite log terms are replaced by
            # nan_to_num, so the division/log warnings are silenced for the
            # same reason as above.
            with np.errstate(divide='ignore', invalid='ignore'):
                log_ratio = np.log(X1 / X2)
            return np.sum(X1 * np.nan_to_num(log_ratio))

        statistical_method = {
            # Dot product divided by the product of the two L2 norms.
            'cosine_similarity':
            lambda X1, X2: np.sum(X1 * X2) /
            (np.sqrt(np.sum(X1**2)) * np.sqrt(np.sum(X2**2))),
            # Largest element-wise absolute difference.
            'maximum_absolute_error':
            lambda X1, X2: np.max(np.abs(X1 - X2)).tolist(),
            'accumulated_relative_error':
            _accumulated_relative_error,
            # Euclidean distance between the two L2-normalised vectors.
            'relative_euclidean_distance':
            lambda X1, X2: np.sqrt(
                np.sum((X1 / np.sqrt(np.sum(X1**2)) - X2 / np.sqrt(np.sum(X2**2)))
                       **2)).tolist(),
            'kullback_leibler_divergence':
            _kullback_leibler_divergence,
            # (mean, std) of each argument, in call order.
            'standard_deviation':
            lambda *X: [(np.mean(x).tolist(), np.std(x).tolist()) for x in X]
        }

        reports = {}
        for input_key in groundtruth:
            if input_key not in inputs:
                continue
            gt_input = groundtruth[input_key].reshape(-1, 1)
            compare_input = inputs[input_key].reshape(-1, 1)
            # The two arrays must hold the same number of elements. Raised
            # explicitly instead of via `assert` so the check is not stripped
            # under `python -O`; without it a 1-element array would silently
            # broadcast. AssertionError is kept so existing callers still see
            # the same exception type.
            if gt_input.size != compare_input.size:
                raise AssertionError(
                    "size mismatch for key {!r}: {} vs {}".format(
                        input_key, gt_input.size, compare_input.size))
            reports[input_key] = {
                key: value(gt_input, compare_input)
                for key, value in statistical_method.items()
            }
        return reports
     
     
    def main(inputs_path="/home/wangmaolin/for_test/tofile/conv_82_memory",
             groundtruth_path="/home/wangmaolin/for_test/onnx_output/onnx_output_conv_82"):
        """Load two raw float32 dumps and print their difference report.

        Both files are read with ``np.fromfile(..., dtype=np.float32)``. The
        shape and dtype of each loaded array are printed, then the report
        returned by ``count_difference`` for the single key ``"data"``.

        Args:
            inputs_path: File holding the array under test. Defaults to the
                path that was previously hard-coded here.
            groundtruth_path: File holding the reference array. Defaults to the
                path that was previously hard-coded here.
        """
        checked = np.fromfile(inputs_path, dtype=np.float32)
        print(checked.shape)
        print(checked.dtype)
        inputs = {"data": checked}

        reference = np.fromfile(groundtruth_path, dtype=np.float32)
        print(reference.shape)
        print(reference.dtype)
        gt_inputs = {"data": reference}

        report = count_difference(gt_inputs, inputs)
        print(report)
     
     
    # Run the comparison only when executed as a script, not when imported.
    if '__main__' == __name__:
        main()
    
    转载请注明出处
  • 相关阅读:
    实现跨域的几种方法
    2015-07-15
    unity3d中给GameObject绑定脚本的代码
    unity3d的碰撞检测及trigger
    区块链 (未完)
    mono部分源码解析
    量化策略分析的研究内容
    mono搭建脚本整理
    unity3d简介
    Hook技术之API拦截(API Hook)
  • 原文地址:https://www.cnblogs.com/lnlin/p/15490644.html
Copyright © 2011-2022 走看看