zoukankan      html  css  js  c++  java
  • TensorRT推理加速基于Tensorflow(keras)的uff格式模型(文件准备)

    一、引子//Windows

    tf(keras)训练好了模型,想要用Nvidia-TensorRT来重构训练好的模型为TRT推理引擎加快推理的速度。

    二、准备文件

    1、训练好模型以后(keras)可以通过以下方式保存keras模型为h5文件

    tf.keras.models.save_model(model, 'keras_model\\classify.h5')
    

     2、再通过以下代码来将h5文件转化为pb文件

    import tensorflow.compat.v1 as tf1
    
    tf1.reset_default_graph()
    tf1.keras.backend.set_learning_phase(0)  # 调用模型前一定要执行该命令
    tf1.disable_v2_behavior()  # 禁止tensorflow2.0的行为
    # 加载hdf5模型
    hdf5_pb_model = tf1.keras.models.load_model('keras_model\\classify.h5')
    
    
    def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
        graph = session.graph
        with graph.as_default():
            #         freeze_var_names = list(set(v.op.name for v in tf1.global_variables()).difference(keep_var_names or []))
            output_names = output_names or []
            #         output_names += [v.op.name for v in tf1.global_variables()]
            print("output_names", output_names)
            input_graph_def = graph.as_graph_def()
            #         for node in input_graph_def.node:
            #             print('node:', node.name)
            print("len node1", len(input_graph_def.node))
            if clear_devices:
                for node in input_graph_def.node:
                    node.device = ""
            frozen_graph = tf1.graph_util.convert_variables_to_constants(session, input_graph_def,
                                                                         output_names)
    
            outgraph = tf1.graph_util.remove_training_nodes(frozen_graph)  # 云掉与推理无关的内容
            print("##################################################################")
            for node in outgraph.node:
                print('node:', node.name)
            print("len node1", len(outgraph.node))
            return outgraph
    
    
    output_folder2 = 'keras_model'
    
    frozen_graph = freeze_session(tf1.compat.v1.keras.backend.get_session(),
                                  output_names=[out.op.name for out in hdf5_pb_model.outputs])
    tf1.train.write_graph(frozen_graph, output_folder2, "classify.pb", as_text=False)
    

    3、注意:以上代码基于tf2.0运行

    4、pb模型文件转化为uff模型文件(tensorrt解析tf模型只能用uff格式)

    首先,先安装TensorRT自带的(两个文件就在trt文件夹里面,cd到路径)

    pip install uff-0.6.5-py2.py3-none-any.whl
    pip install graphsurgeon-0.4.1-py2.py3-none-any.whl
    

     5、执行(cd到路径,执行以下过程需要tf1.x版本,否则报错,没有Graphdef)

    转换

    convert-to-uff xxxx.pb -o xxxx.uff
    

    查看模型信息

    convert-to-uff xxxx.uff -l
    

     参考:

    【Tensorflow2.0】8、tensorflow2.0_hdf5_savedmodel_pb模型转换
    
  • 相关阅读:
    前台js加密实例
    Redis 核心原理
    Rredis的安装与常用命令
    Zookeeper的源码环境的搭建和源码解读
    Zookeeper集群搭建
    Zookeeper的客户端API使用
    Zookeeper介绍
    HashMap的死锁 与 ConcurrentHashMap
    定时任务 & 定时线程池 ScheduledThreadPoolExecutor
    Fork/Join框架
  • 原文地址:https://www.cnblogs.com/buctyk/p/12932663.html
Copyright © 2011-2022 走看看