zoukankan      html  css  js  c++  java
  • tensorflow_tflite专题

    tensorflow_tflite专题
    本文章主要包括两大问题:

    tflite的转换:如何转换得到tflite?
    tflite的测试:如何测试或者说如何在PC端使用tflite?

    问题一:如何转换得到tflite

    分为两个过程,步骤:cheakpoint→pb模型→tflite模型

    • step1:cheakpoint→tflite_graph.pb:
      使用object_detection的export_tflite_ssd_graph.py,结果生成tflite_graph.pb和tflite_graph.pbtxt两个文件

    超参数:
    "output_directory":输出的文件夹
    "pipeline_config_path":网络配置文件
    "trained_checkpoint_prefix":你的cheakpoint文件

    • step2:tflite_graph.pb→out_put.tflite:
      使用convert.py程序讲pb转换为tflite,这里的pb是上一步转换得到了,不能是其他来源的pb模型
    import tensorflow as tf
    
    # 需要配置
    in_path = "tflite_graph.pb"
    
    # 模型输入节点对于object_detection是固定的,不需改动,但是shape是和网络有关
    input_tensor_name = ["normalized_input_image_tensor"]
    input_tensor_shape = {"normalized_input_image_tensor":[1,256,256,3]}
    # 模型输出节点,对于object_detection是固定的,不需改动
    classes_tensor_name = ['TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', 'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']
    
    converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path,input_tensor_name, classes_tensor_name,input_tensor_shape)
    
    converter.allow_custom_ops=True
    converter.post_training_quantize = True
    tflite_model = converter.convert()
    
    open("output.tflite", "wb").write(tflite_model)
    
    print("done")
    
    

    问题二:如何测试或者说如何在PC端使用tflite?

    这里给出代码:

    import numpy as np
    import tensorflow as tf
    import cv2 #用来读取图片并进行预处理
    import glob #读取某文件夹所有测试图片
    import time #主要是用来计算推理花费时间
    
    # Load TFLite model and allocate tensors.
    model_path="output_fp16.tflite"  #tflite路径
    interpreter = tf.lite.Interpreter(model_path)
    interpreter.allocate_tensors()
    
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
    print(input_details)
    print(output_details) #在这里可以看到tflite的输入输出的节点信息
    
    def detection(img_src):
        img = cv2.resize(img_src, (256, 256))
        img = img / 128 - 1
        input_data = np.expand_dims(img, 0)
        input_data = input_data.astype(np.float32)
        #以上是对图片经行尺寸变换、归一化、添加维度和类型转换,以便和输入节点对应
    
        index = input_details[0]['index']
        interpreter.set_tensor(index, input_data)
        interpreter.invoke() #启动
        
        output0 = interpreter.get_tensor(output_details[0]['index'])  # bbox
        output1 = interpreter.get_tensor(output_details[1]['index'])  # bbox
        output2 = interpreter.get_tensor(output_details[2]['index'])  # bbox
        output3 = interpreter.get_tensor(output_details[3]['index'])  # 概率
        #在这里你可以通过print查看4个输出的信息
        #分别时object_detection的信息:
        #对于我来讲,人脸检测不涉及类别,所以我只用到
        # output0:位置信息
        # output2:对应的概率
        
        #我只要概率最大的人脸,且概率>0.6保持,否则讲概率置为0
        output3=output3[0][0] if output3[0][0] > 0.6 else 0
        
        return bbox,output3 #返回概率信息和其位置信息
    
    imgs_path = glob.glob('../../test_iamge/*')
    
    
    for img_path in imgs_path:
        t1=time.time()
        img=cv2.imread(img_path)
        sp = img.shape
        bbox,confidence=detection(img)
        if confidence!=0:
            print('置信度=',confidence,'   bbox=',bbox,end='   ')
            y1 = int(bbox[0][0][0] * sp[0])
            x1 = int(bbox[0][0][1] * sp[1])
            y2 = int(bbox[0][0][2] * sp[0])
            x2 = int(bbox[0][0][3] * sp[1])
    
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 3)
            print('time=',time.time()-t1)
            cv2.namedWindow(str(confidence*100)[2:6]+'%', 0)
            cv2.imshow(str(confidence*100)[2:6]+'%', img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        else:
            print('time=',time.time()-t1)
    
    
  • 相关阅读:
    centos7查看启动的进程并杀死
    3.3 Zabbix容器安装
    windows下XAMPP集成环境中,MySQL数据库的使用
    pip淘宝镜像安装
    服务起不来,查看ps axj 看服务是否为守护进程(TPGID 为-1)
    dcloud_base连接失败(root:admin123!@#qwe@tcp(192.168.8.205:3306)/dcloud_base) Error 1129: Host '192.168.8.205' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'
    服务部署_软加密之后要重新启动才能生效
    AWS Certified Solutions Architect
    Cloud Formation Mapping经常用于AMI ID的region映射
    CloudFormation StackSets
  • 原文地址:https://www.cnblogs.com/thgpddl/p/13550417.html
Copyright © 2011-2022 走看看