  • Viewing the node information of a TensorFlow .pb model file

    To view the node information stored in a TensorFlow .pb model file:

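    import tensorflow as tf

    # tf.GraphDef is the TF 1.x name; on TF 2.x use tf.compat.v1.GraphDef.
    # Parsing the serialized GraphDef does not need a Session.
    graph_def = tf.GraphDef()
    with open('./quantized_model.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())
    print(graph_def)  # prints every node in protobuf text format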

    Output:

    # ...
    node {
      name: "FullyConnected/BiasAdd"
      op: "BiasAdd"
      input: "FullyConnected/MatMul"
      input: "FullyConnected/b/read"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "data_format"
        value {
          s: "NHWC"
        }
      }
    }
    node {
      name: "FullyConnected/Softmax"
      op: "Softmax"
      input: "FullyConnected/BiasAdd"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
    }
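    For a real model this text dump can get very long. A minimal sketch of a more compact listing follows, assuming only that the parsed object exposes the standard GraphDef fields `node`, `name`, `op` and `input`. The helpers `summarize_nodes` and `format_summary` are hypothetical names, and the `SimpleNamespace` objects merely stand in for a parsed .pb file:

```python
from types import SimpleNamespace


def summarize_nodes(graph_def):
    """Return one (name, op, inputs) tuple per node of a GraphDef-like object."""
    # A real GraphDef has a repeated `node` field; each NodeDef has
    # `name`, `op` and a repeated `input` field. Nothing else is used here.
    return [(n.name, n.op, list(n.input)) for n in graph_def.node]


def format_summary(rows):
    """Render rows as 'name (op) <- input1, input2', one node per line."""
    lines = []
    for name, op, inputs in rows:
        line = f"{name} ({op})"
        if inputs:
            line += " <- " + ", ".join(inputs)
        lines.append(line)
    return "\n".join(lines)


if __name__ == "__main__":
    # Hypothetical stand-in for the two nodes shown in the dump above.
    fake_graph = SimpleNamespace(node=[
        SimpleNamespace(name="FullyConnected/BiasAdd", op="BiasAdd",
                        input=["FullyConnected/MatMul", "FullyConnected/b/read"]),
        SimpleNamespace(name="FullyConnected/Softmax", op="Softmax",
                        input=["FullyConnected/BiasAdd"]),
    ])
    print(format_summary(summarize_nodes(fake_graph)))
```

    With a real model, pass the `graph_def` parsed by the code above instead of `fake_graph`.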

    References: https://tang.su/2017/01/export-TensorFlow-network/

    https://github.com/tensorflow/tensorflow/issues/15689

    Some of the core code:

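    import tensorflow as tf

    # TF 1.x API; on TF 2.x use tf.compat.v1 and call tf.compat.v1.disable_eager_execution() first.
    with tf.Session() as sess:
        with open('./graph.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        print(graph_def)
        # Import the nodes into the session's graph and fetch tensor 'out:0'
        # (output 0 of node 'out'); this runs as-is only if 'out' needs no fed inputs.
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))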
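    The string in `return_elements=['out:0']` follows TensorFlow's `node_name:output_index` convention, and the `input:` fields in the dumps use the same form: a bare name means output 0, and a leading `^` marks a control dependency. Note also that `tf.import_graph_def` prefixes imported names with `import/` by default, so looking the tensor up later via `graph.get_tensor_by_name` would need `'import/out:0'`. A small sketch of splitting such names, using the hypothetical helper `parse_tensor_name`:

```python
def parse_tensor_name(name):
    """Split a GraphDef input/fetch string into (node_name, output_index, is_control).

    Conventions handled: "node:k" is output k of "node", a bare "node"
    means output 0, and a leading "^" marks a control dependency.
    """
    is_control = name.startswith("^")
    if is_control:
        name = name[1:]
    node, sep, index = name.rpartition(":")
    if sep and index.isdigit():
        return node, int(index), is_control
    # No ":k" suffix: the whole string is the node name, output 0.
    return name, 0, is_control


if __name__ == "__main__":
    for s in ["out:0", "FullyConnected/BiasAdd", "^init"]:
        print(s, "->", parse_tensor_name(s))
```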

    This is part of my TensorFlow frozen graph; I have named the input and output nodes.

    >>> import tensorflow as tf
    >>> g = tf.GraphDef()
    >>> g.ParseFromString(open('frozen_graph.pb','rb').read())
    >>> g
    node {
      name: "input"
      op: "Placeholder"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "shape"
        value {
          shape {
            dim {
              size: -1
            }
            dim {
              size: 68
            }
          }
        }
      }
    }
    ...
    node {
      name: "output"
      op: "Softmax"
      input: "add"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
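    When the input and output names are not known in advance, they can often be guessed from the node list: inputs are usually `Placeholder` nodes, and outputs are nodes that no other node consumes. A heuristic sketch under those assumptions, with the hypothetical helper `find_inputs_and_outputs` taking the `(name, op, inputs)` tuples a GraphDef yields; the middle nodes in the example are invented, since the dump above elides them with `...`:

```python
def find_inputs_and_outputs(nodes):
    """Guess the input and output nodes of a frozen graph.

    nodes: iterable of (name, op, inputs) tuples, e.g. built from a GraphDef.
    Inputs are taken to be Placeholder nodes; outputs are nodes that no other
    node consumes (a heuristic: an unused Const would also be reported).
    """
    nodes = list(nodes)
    consumed = set()
    for _, _, inputs in nodes:
        for inp in inputs:
            # Strip the control-dependency marker '^' and any ':k' output index.
            consumed.add(inp.lstrip("^").split(":")[0])
    placeholders = [name for name, op, _ in nodes if op == "Placeholder"]
    outputs = [name for name, _, _ in nodes if name not in consumed]
    return placeholders, outputs


if __name__ == "__main__":
    # "input" and "output" come from the dump above; the nodes in between
    # are hypothetical, because the dump elides them.
    nodes = [
        ("input", "Placeholder", []),
        ("W", "Const", []),
        ("MatMul", "MatMul", ["input", "W"]),
        ("b", "Const", []),
        ("add", "Add", ["MatMul", "b"]),
        ("output", "Softmax", ["add"]),
    ]
    print(find_inputs_and_outputs(nodes))  # (['input'], ['output'])
```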
    

    I ran this model with the following code (CELL is the name of the directory where my file is located):

    final String MODEL_FILE = "file:///android_asset/" + CELL + "/optimized_graph.pb" ;
    final String INPUT_NODE = "input" ;
    final String OUTPUT_NODE = "output" ;
    final int[] INPUT_SIZE = {1,68} ;
    float[] RESULT = new float[8];
    
    inferenceInterface = new TensorFlowInferenceInterface();
    inferenceInterface.initializeTensorFlow(getAssets(),MODEL_FILE) ;
    inferenceInterface.fillNodeFloat(INPUT_NODE,INPUT_SIZE,input);
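    The dims passed to `fillNodeFloat` (`INPUT_SIZE = {1,68}`) should agree with the `input` Placeholder shape in the dump (-1 × 68, where -1 leaves the batch size open), and the flat `input` array should hold 1 × 68 = 68 floats. A small Python sketch of that consistency check, using the hypothetical helper `check_feed`; it is not part of the TensorFlow Android API:

```python
from math import prod


def check_feed(placeholder_shape, feed_dims, flat_values):
    """Check a flat float feed against a Placeholder shape.

    placeholder_shape: dims from the graph, -1 meaning "any size" (e.g. batch).
    feed_dims: concrete dims passed alongside the flat array (like INPUT_SIZE).
    Returns a list of problems; an empty list means the feed looks consistent.
    """
    problems = []
    if len(placeholder_shape) != len(feed_dims):
        problems.append(f"rank mismatch: {placeholder_shape} vs {feed_dims}")
    else:
        for i, (want, got) in enumerate(zip(placeholder_shape, feed_dims)):
            if want != -1 and want != got:
                problems.append(f"dim {i}: expected {want}, got {got}")
    if prod(feed_dims) != len(flat_values):
        problems.append(
            f"flat length {len(flat_values)} != product of dims {prod(feed_dims)}")
    return problems


if __name__ == "__main__":
    # Placeholder shape from the dump and INPUT_SIZE from the Java code.
    print(check_feed([-1, 68], [1, 68], [0.0] * 68))  # []
```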
    

    And finally:

    inferenceInterface.readNodeFloat(OUTPUT_NODE,RESULT);
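    `RESULT` is filled from the `output` node, which is a `Softmax`, so its 8 entries should be class probabilities. A minimal sketch of turning them into a prediction, assuming that; the helper `top_class` and its tolerance are hypothetical:

```python
import math


def top_class(probs, tol=1e-3):
    """Return (index, probability) of the most likely class.

    Assumes probs is a Softmax output (non-negative, summing to about 1).
    Raises ValueError otherwise, e.g. for an all-zero array that was
    never written to.
    """
    if not probs:
        raise ValueError("empty output")
    if any(p < 0 for p in probs) or not math.isclose(sum(probs), 1.0, abs_tol=tol):
        raise ValueError("values do not look like Softmax probabilities")
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


if __name__ == "__main__":
    # Hypothetical 8-class output standing in for RESULT.
    result = [0.01, 0.02, 0.05, 0.70, 0.10, 0.04, 0.05, 0.03]
    print(top_class(result))  # (3, 0.7)
```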
    
  • Original article: https://www.cnblogs.com/bonelee/p/8462578.html