zoukankan      html  css  js  c++  java
  • 两种从 TensorFlow 的 checkpoint生成 frozenpb 的方法

    1. 从 ckpt-.data,ckpt-.index 和 .meta 生成 frozenpb

    import os
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    
    
    def freeze_graph(input_checkpoint,output_graph):
        '''
        :param input_checkpoint:
        :param output_graph: PB模型保存路径
        :return:
        '''
        # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
        output_node_names = "outputs"
        saver = tf.train.import_meta_graph(os.path.join(os.path.split(input_checkpoint)[0], 'graph.meta'), clear_devices=True)
     
        with tf.Session() as sess:
            saver.restore(sess, input_checkpoint) #恢复图并得到数据
            output_graph_def = graph_util.convert_variables_to_constants(  
                # 模型持久化,将变量值固定
                sess=sess,
                input_graph_def=sess.graph_def,# 等于:sess.graph_def
                output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
     
            with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
                f.write(output_graph_def.SerializeToString()) #序列化输出
            print("%d ops in the final graph." % len(output_graph_def.node)) 
            #得到当前图有几个操作节点
    
    if __name__ == "__main__":
        # 输入ckpt模型路径
        input_checkpoint='ckpt_path/ckpt-10000'
        # 输出pb模型的路径
        out_pb_path="some_path/frozen_model.pb"
        # 调用freeze_graph将ckpt转为pb
        freeze_graph(input_checkpoint,out_pb_path)
    
    

    2. 从网络代码和 ckpt-.data 文件生成 frozenpb

    import tensorflow as tf
    import os
    from tensorflow.python.tools import freeze_graph
    
    import network  # 导入网络结构
    
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 设置GPU
    model_path = "ckpt_path/ckpt-10000"
    
    def main():
        tf.reset_default_graph()
        input_node = tf.placeholder(
            tf.float32, shape=(None,112, 96, 3)
        ) 
        input_node = tf.identity(input_node,name="inputs") # 设置输入节点的名字,这里可以自定义名称
        flow = network(input_node)
        flow = tf.identity(flow, name="outs") # 设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, model_path)
            # 保存图
            tf.train.write_graph(sess.graph_def, "logdir/", "graph.pb")
            # 把图和参数结构一起
            freeze_graph.freeze_graph(
                "logdir/graph.pb", # 上面保存的图结构 graph.pb
                "",
                False,
                model_path,
                "outs",
                "save/restore_all", # 默认恢复所有
                "save/Const:0", # 默认常量
                "some_path/frozen.pb", # 保存frozen.pb
                False,
                "",
            )
        print("done")
    
    
    if __name__ == "__main__":
        main()
    

    3. 打印 网络中节点的名字

    import tensorflow as tf
    
    
    if __name__ == "__main__":
        checkpoint_path = '../model_fintune/ckpt-1400'  
        reader = tf.train.NewCheckpointReader(checkpoint_path)  
        var_to_shape_map = reader.get_variable_to_shape_map()  
        
        for key in var_to_shape_map:  
            print("tensor name: ", key)  
            # print(reader.get_tensor(key))
    

    或者通过

    import tensorflow as tf
    
    def printTensors(pb_file):
    
        # read pb into graph_def
        with tf.gfile.GFile(pb_file, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
    
        # import graph_def
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def)
    
        # print operations
        for op in graph.get_operations():
            print(op.name)
    
    printTensors("path-to-my-pbfile.pb")
    

    4. 两种方法对比

    如果是自己的代码训练的模型,有网络结构,有 ckpt 文件,最好是使用第二种方法,使用起来很灵活,可以进行各种自定义,比如修改输入输出的节点名字,网络有多个路径的时候可以自定义输出路径。第一种方法,应该也能达到第二种方法的效果,因为它们本来就是等价的,可能会有些麻烦。第一种方法的好处就是快,不要去翻那些杂糅在一起的网络结构。

  • 相关阅读:
    查看端口被占用
    Eclipse导入包
    Eclipse中构造方法自动生成
    Eclipse中get/set方法自动生成
    Eclipse改字体大小
    设计六原则
    类的关系
    JAVA实现多线程下载
    try...catch的前世今生
    447. 回旋镖的数量
  • 原文地址:https://www.cnblogs.com/willwell/p/12196101.html
Copyright © 2011-2022 走看看