  • Viewing the node information of a TensorFlow .pb model file

    To view the nodes stored in a TensorFlow .pb model file:

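    import tensorflow as tf

    # Parse the serialized GraphDef protobuf and print it in text form.
    # (A Session is not needed just to inspect the file.)
    with open('./quantized_model.pb', 'rb') as f:
        graph_def = tf.GraphDef()  # tf.compat.v1.GraphDef() on TensorFlow 2.x
        graph_def.ParseFromString(f.read())
    print(graph_def)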
            

    Output:

    # ...
    node {
      name: "FullyConnected/BiasAdd"
      op: "BiasAdd"
      input: "FullyConnected/MatMul"
      input: "FullyConnected/b/read"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "data_format"
        value {
          s: "NHWC"
        }
      }
    }
    node {
      name: "FullyConnected/Softmax"
      op: "Softmax"
      input: "FullyConnected/BiasAdd"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
    }

    References: https://tang.su/2017/01/export-TensorFlow-network/

    https://github.com/tensorflow/tensorflow/issues/15689

    Core code for loading the graph and running it:

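    import tensorflow as tf

    with tf.Session() as sess:
        with open('./graph.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        print(graph_def)
        # Import the nodes into the session's graph and fetch the tensor 'out:0'.
        # If 'out' depends on a Placeholder, sess.run also needs a feed_dict.
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))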
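    `return_elements` uses the names as stored in the file (`out:0`), but the imported nodes themselves live under a prefix: `tf.import_graph_def` uses `import` unless its `name` argument says otherwise, so the same tensor appears in `sess.graph` as `import/out:0`. Tensors are addressed as `node_name:output_index`. A small sketch, where `imported_tensor_name` is a hypothetical helper (not a TensorFlow function) that builds such names, e.g. for `sess.graph.get_tensor_by_name`:

```python
def imported_tensor_name(node_name, output_index=0, scope="import"):
    """Return the name under which an imported node's output can be looked up.

    Tensors are named '<node>:<output index>'. tf.import_graph_def prepends
    its `name` argument ("import" by default) as a scope; pass scope="" for
    graphs imported with name="".
    """
    tensor = f"{node_name}:{output_index}"
    return f"{scope}/{tensor}" if scope else tensor


if __name__ == "__main__":
    print(imported_tensor_name("out"))            # import/out:0
    print(imported_tensor_name("out", scope=""))  # out:0
```

    A second import into the same graph without `name=` gets a uniquified prefix such as `import_1`, which can be passed as `scope`.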

    This is part of my TensorFlow frozen graph; I have named the input and output nodes.

    >>> import tensorflow as tf
    >>> g = tf.GraphDef()
    >>> g.ParseFromString(open('frozen_graph.pb','rb').read())
    >>> g
    node {
      name: "input"
      op: "Placeholder"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "shape"
        value {
          shape {
            dim {
              size: -1
            }
            dim {
              size: 68
            }
          }
        }
      }
    }
    ...
    node {
      name: "output"
      op: "Softmax"
      input: "add"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
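    The `input` Placeholder records its shape in the `shape` attr, where `size: -1` marks an unknown (batch) dimension; fixing that dimension to 1 gives the `{1,68}` that the Android code below passes as `INPUT_SIZE`. A minimal sketch, assuming the attr layout printed above; `placeholder_shape` is a hypothetical helper, and the stand-in node replaces a real `NodeDef`:

```python
from types import SimpleNamespace


def placeholder_shape(node, batch_size=1):
    """Return a Placeholder's declared shape with unknown dims filled in.

    Reads node.attr["shape"].shape.dim[i].size, the layout printed above;
    -1 marks an unknown dimension (usually the batch) and becomes batch_size.
    """
    dims = [d.size for d in node.attr["shape"].shape.dim]
    return [batch_size if size == -1 else size for size in dims]


if __name__ == "__main__":
    # Stand-in for the "input" node above; with TensorFlow, pass the real NodeDef.
    node = SimpleNamespace(attr={"shape": SimpleNamespace(shape=SimpleNamespace(
        dim=[SimpleNamespace(size=-1), SimpleNamespace(size=68)]))})
    print(placeholder_shape(node))  # [1, 68]
```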
    

    I ran this model with the following code (CELL is the name of the directory where my file is located):

    final String MODEL_FILE = "file:///android_asset/" + CELL + "/optimized_graph.pb" ;
    final String INPUT_NODE = "input" ;
    final String OUTPUT_NODE = "output" ;
    final int[] INPUT_SIZE = {1,68} ;
    float[] RESULT = new float[8];
    
    inferenceInterface = new TensorFlowInferenceInterface();
    inferenceInterface.initializeTensorFlow(getAssets(),MODEL_FILE) ;
    inferenceInterface.fillNodeFloat(INPUT_NODE,INPUT_SIZE,input);
    

    And finally:

    inferenceInterface.readNodeFloat(OUTPUT_NODE,RESULT);
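    The Java calls move data through flat `float[]` arrays: `fillNodeFloat` takes the values together with their dimensions, and `readNodeFloat` copies the `output` Softmax into `RESULT`, whose length 8 presumably matches the number of classes. The bookkeeping is sketched here in Python, the language of the earlier examples, assuming row-major flattening; `input_length` and `top_class` are hypothetical helpers. The input array must hold the product of the dimensions (1 × 68 = 68 floats), and the predicted class is the index of the largest probability.

```python
from math import prod


def input_length(dims):
    """Number of floats a flat input buffer needs for the given dims."""
    return prod(dims)


def top_class(probabilities):
    """Return (index, probability) of the largest Softmax output."""
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return best, probabilities[best]


if __name__ == "__main__":
    print(input_length([1, 68]))       # 68
    print(top_class([0.1, 0.7, 0.2]))  # (1, 0.7)
```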
    
  • Original article: https://www.cnblogs.com/bonelee/p/8462578.html