zoukankan      html  css  js  c++  java
  • 加载并运行ONNX格式的模型,并获取模型运行过程中每一层的输出

    运行模型,获取模型的最终输出结果

    """
    测试onnx文件,获取浮点模型输出
    wangmaolin-1029
    """
    
    import onnxruntime
    import pdb
    import cv2
    import numpy as np
    
    def test_seg(seg):
        print(seg.shape)
        cnt_nan = 0
        for i in range(80):
            for j in range(256):
                if np.isnan(seg[i][j]):
                    cnt_nan += 1
        print(cnt_nan, "nan")
        h = np.histogram(seg)
        print(h)
    
    def main():
        """Run the ONNX model once on a single hard-coded image.

        Crops the image, prints the first 10 pixels of the cropped top row,
        runs one inference and prints 'done.'.

        Raises:
            FileNotFoundError: if OpenCV cannot read the input image.
        """
        path = '/home/wangmaolin/models/orig-config-simplify.onnx'
        print('path:', path)

        image_path = '/home/wangmaolin/fixedpoint_21.2/allImage/CA_S202_7012C_FV180_V9_20200904_100626_050.mf4_remap_4I_G3NWide_0163229.png'
        # -1 reads the file unchanged (no colour / bit-depth conversion).
        image = cv2.imread(image_path, -1)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('could not read image: {}'.format(image_path))
        # Keep rows from 192 (= 512 - 320) down and a 1024-wide column window.
        image = image[512-320:, 320:320+1024]
        # Add two leading axes and convert to float32 for the model input.
        image_cpu = np.asarray(image[np.newaxis, np.newaxis, :], dtype=np.float32)

        for i in range(10):
            print(image[0, i])

        teacher_session = onnxruntime.InferenceSession(path)
        # Use the model's real input name rather than assuming it is 'input'.
        input_name = teacher_session.get_inputs()[0].name
        # An empty output list asks onnxruntime for every model output; the
        # result is kept for inspection (e.g. under pdb).
        soft_label = teacher_session.run([], {input_name: image_cpu})

        print('done.')


    if __name__ == '__main__':
        main()
    

    运行模型,获取模型每一层的输出

    # 获取onnx模型的每一层输出结果 并统计每一层的min max
    # wangmaolin-1029
    # 统计模型每一层输出的min max,并计算所有总的min max
    
    import collections
    import onnx
    import onnxruntime
    import numpy as np
    import cv2
    import copy
    from collections import OrderedDict
    import pdb
    import os
    import json
    
    
    def get_layer_output(model, image):
        """Run an ONNX model on one input and return every node's output.

        Each tensor name produced by a node of ``model.graph`` is temporarily
        registered as a graph output, so onnxruntime returns the intermediate
        activations as well as the model's original outputs.

        Args:
            model: the loaded ONNX model (as returned by ``onnx.load``). Its
                ``graph.output`` list is extended only while the model is
                serialized and is restored before this function returns.
            image: array fed to *every* model input (the visible callers pass
                the float32 array built by ``tran_png2bin_input``).

        Returns:
            OrderedDict mapping output tensor name -> numpy array, in the order
            reported by ``InferenceSession.get_outputs()``.
        """
        ori_output = copy.deepcopy(model.graph.output)
        # Names already exposed as graph outputs must not be added again.
        seen = {out.name for out in ori_output}
        try:
            for node in model.graph.node:
                for name in node.output:
                    # An empty name marks an omitted optional output.
                    if name and name not in seen:
                        model.graph.output.extend([onnx.ValueInfoProto(name=name)])
                        seen.add(name)
            serialized = model.SerializeToString()
        finally:
            # Restore the caller's outputs. Without this, every call on the same
            # model appended another full set of duplicate outputs.
            del model.graph.output[:]
            model.graph.output.extend(ori_output)

        ort_session = onnxruntime.InferenceSession(serialized)
        ort_inputs = {inp.name: image for inp in ort_session.get_inputs()}
        outputs = [out.name for out in ort_session.get_outputs()]
        ort_outs = ort_session.run(outputs, ort_inputs)
        return OrderedDict(zip(outputs, ort_outs))
    
    def get_layer_min_max(ort_outs):
        """Compute ``[min, max]`` for every tensor in *ort_outs*.

        Each name and its ``[min, max]`` pair are printed as they are computed.

        Args:
            ort_outs: mapping of tensor name -> array-like (e.g. the
                OrderedDict returned by ``get_layer_output``).

        Returns:
            dict mapping each name to ``[min, max]`` as plain Python scalars, so
            the result can be passed to ``json.dumps``. A zero-size tensor gets
            ``[inf, -inf]``, the identity for a running min/max, so merging it
            into other results leaves them unchanged.
        """
        layers_min_max = {}

        for key, value in ort_outs.items():
            arr = np.asarray(value)
            if arr.size == 0:
                # np.min / np.max raise ValueError on zero-size arrays.
                layers_min_max[key] = [float('inf'), float('-inf')]
            else:
                # .item() converts numpy scalars to int/float/bool.
                layers_min_max[key] = [arr.min().item(), arr.max().item()]
            print(key)
            print(layers_min_max[key])

        return layers_min_max
    
    def tran_png2bin_input(image_path):
        """Load an image file and turn it into the model's float32 input array.

        The file is read unchanged (``cv2.imread(..., -1)``). Rows from index 192
        (= 512 - 320) onward and columns 320..1343 are kept, and two leading
        axes are added.

        Args:
            image_path: path of the image to load.

        Returns:
            float32 ndarray. For a single-channel (2-D) image of at least
            512 x 1344 this is shape (1, 1, 320, 1024); other sizes give a
            correspondingly smaller crop.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(image_path, -1)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising,
            # which previously surfaced as an opaque TypeError on slicing.
            raise FileNotFoundError('could not read image: {}'.format(image_path))
        image = image[512 - 320:, 320:320 + 1024]
        return image[np.newaxis, np.newaxis, :].astype(np.float32)
    
    if __name__ == '__main__':
        # Seed the per-layer [min, max] table from one reference image, then
        # widen it with every image found in image_dir (sorted by file name).
        model = onnx.load('/home/wangmaolin/models/orig-config-simplify.onnx')
        image_dir = '/home/wangmaolin/for_test/min_max_images'
        seed_path = '/home/wangmaolin/for_test/image_15.png'

        seed_outs = get_layer_output(model, tran_png2bin_input(seed_path))
        all_layers_min_max = get_layer_min_max(seed_outs)

        image_names = sorted(os.listdir(image_dir))
        print("all images: ", len(image_names))

        for name in image_names:
            print("image_name: ", name)

            full_path = os.path.join(image_dir, name)
            per_layer = get_layer_min_max(
                get_layer_output(model, tran_png2bin_input(full_path)))
            print(per_layer)

            # Widen each layer's running [min, max] in place.
            for layer, (lo, hi) in per_layer.items():
                bounds = all_layers_min_max[layer]
                bounds[0] = min(bounds[0], lo)
                bounds[1] = max(bounds[1], hi)

        print("all images layers min max res")
        print(all_layers_min_max)
        
    
    转载请注明出处
  • 相关阅读:
    jsp报源码
    c#简单写售票系统
    linux常用命令大全[转]
    【转载】大型网站渗透思之信息收集
    Ajax初窥
    屏蔽win10中文输入法
    win10禁止更新的方法
    win10进入到安全模式的三种方法
    7代CPU安装win7的方法
    python的输出问题
  • 原文地址:https://www.cnblogs.com/lnlin/p/15490648.html
Copyright © 2011-2022 走看看