zoukankan      html  css  js  c++  java
  • TensorRT推理加速基于Tensorflow(keras)的uff格式模型(文件准备)

    一、引子//Windows

    tf(keras)训练好了模型,想要用Nvidia-TensorRT来重构训练好的模型为TRT推理引擎加快推理的速度。

    二、准备文件

    1、训练好模型以后(keras)可以通过以下方式保存keras模型为h5文件

    tf.keras.models.save_model(model, 'keras_model\\classify.h5')
    

     2、再通过以下代码来将h5文件转化为pb文件

    import tensorflow.compat.v1 as tf1
    
    tf1.reset_default_graph()
    tf1.keras.backend.set_learning_phase(0)  # 调用模型前一定要执行该命令
    tf1.disable_v2_behavior()  # 禁止tensorflow2.0的行为
    # 加载hdf5模型
    hdf5_pb_model = tf1.keras.models.load_model('keras_model\\classify.h5')
    
    
    def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
        graph = session.graph
        with graph.as_default():
            #         freeze_var_names = list(set(v.op.name for v in tf1.global_variables()).difference(keep_var_names or []))
            output_names = output_names or []
            #         output_names += [v.op.name for v in tf1.global_variables()]
            print("output_names", output_names)
            input_graph_def = graph.as_graph_def()
            #         for node in input_graph_def.node:
            #             print('node:', node.name)
            print("len node1", len(input_graph_def.node))
            if clear_devices:
                for node in input_graph_def.node:
                    node.device = ""
            frozen_graph = tf1.graph_util.convert_variables_to_constants(session, input_graph_def,
                                                                         output_names)
    
            outgraph = tf1.graph_util.remove_training_nodes(frozen_graph)  # 云掉与推理无关的内容
            print("##################################################################")
            for node in outgraph.node:
                print('node:', node.name)
            print("len node1", len(outgraph.node))
            return outgraph
    
    
    output_folder2 = 'keras_model'
    
    frozen_graph = freeze_session(tf1.compat.v1.keras.backend.get_session(),
                                  output_names=[out.op.name for out in hdf5_pb_model.outputs])
    tf1.train.write_graph(frozen_graph, output_folder2, "classify.pb", as_text=False)
    

    3、注意:以上代码基于tf2.0运行

    4、pb模型文件转化为uff模型文件(tensorrt解析tf模型只能用uff格式)

    首先,先安装TensorRT自带的(两个文件就在trt文件夹里面,cd到路径)

    pip install uff-0.6.5-py2.py3-none-any.whl
    pip install graphsurgeon-0.4.1-py2.py3-none-any.whl
    

     5、执行(cd到路径,执行以下过程需要tf1.x版本,否则报错,没有Graphdef)

    转换

    convert-to-uff xxxx.pb -o xxxx.uff
    

    查看模型信息

    convert-to-uff xxxx.uff -l
    

     参考:

    【Tensorflow2.0】8、tensorflow2.0_hdf5_savedmodel_pb模型转换
    
  • 相关阅读:
    tyvj 1031 热浪 最短路
    【bzoj2005】 [Noi2010]能量采集 数学结论(gcd)
    hdu 1394 Minimum Inversion Number 逆序数/树状数组
    HDU 1698 just a hook 线段树,区间定值,求和
    ZeptoLab Code Rush 2015 C. Om Nom and Candies 暴力
    ZeptoLab Code Rush 2015 B. Om Nom and Dark Park DFS
    ZeptoLab Code Rush 2015 A. King of Thieves 暴力
    hdoj 5199 Gunner map
    hdoj 5198 Strange Class 水题
    vijos 1659 河蟹王国 线段树区间加、区间查询最大值
  • 原文地址:https://www.cnblogs.com/buctyk/p/12932663.html
Copyright © 2011-2022 走看看