zoukankan      html  css  js  c++  java
  • PB文件相关操作

    一、获取pb模型的节点名称

    import tensorflow as tf
    import os
    
    
    model_dir = ‘  ’
    model_name = ' '
    
    def create_graph():
        with tf.gfile.FastGFile(os.path.join( model_dir, model_name), 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
    
    create_graph()
    tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    f = open('/home/yk/Desktop/Conv-TasNet-master-20190508/op.txt', 'wb')
    for tensor_name in tensor_name_list:
        print(tensor_name,'
    ')
        f.write(tensor_name + '
    ')

    二、ckpt转换为pb模型

    from tensorflow.python.tools import inspect_checkpoint as chkp
    import tensorflow as tf
    
    saver = tf.train.import_meta_graph("./ade20k/model.ckpt-27150.meta", clear_devices=True)
    
    #【敲黑板!】这里就是填写输出节点名称惹
    output_nodes = ["xxx"] 
    
    with tf.Session(graph=tf.get_default_graph()) as sess:
        input_graph_def = sess.graph.as_graph_def()
        saver.restore(sess, "./ade20k/model.ckpt-27150")
        output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_nodes)
        with open("frozen_model.pb", "wb") as f:
            f.write(output_graph_def.SerializeToString())

    三、pb TensorBoard 可视化

    1. 从pb文件中恢复计算图
    import tensorflow as tf
    
    model = 'model.pb' #请将这里的pb文件路径改为自己的
    graph = tf.get_default_graph()
    graph_def = graph.as_graph_def()
    graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
    tf.import_graph_def(graph_def, name='graph')
    summaryWriter = tf.summary.FileWriter('log/', graph)
    
    执行以上代码就会生成文件在log/events.out.tfevents.1535079670.DESKTOP-5IRM000。
    
    
    2. 在tensorboard中加载
    
    tensorboard --logdir path/to/log
    
    3. 在浏览器中打开链接

    附加:ckpt模型节点获取

    import os
    from tensorflow.python import pywrap_tensorflow
    
    checkpoint_path = os.path.join('./ade20k', "model.ckpt-27150")
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)
        # print(reader.get_tensor(key)) #相应的值
  • 相关阅读:
    非对称加密-RSA公钥加密,私钥解密,私钥加签,公钥验签
    设置mysql数据库本地连接或外部可连接
    mysql自增长主键,删除数据后,将主键顺序重新排序
    非Service层和Controller层调用ssm框架中的方法
    DES加密算法(密文只有字符串和数字)java和android加密的结果一致(可放在url中)
    SpringBoot ajax Restful整合
    java中线程执行流程详解
    在 CSS 中直接引用 fontawesome 图标(附码表)
    C++内存管理~
    操作系统那些事儿
  • 原文地址:https://www.cnblogs.com/kang06/p/10832578.html
Copyright © 2011-2022 走看看