zoukankan      html  css  js  c++  java
  • tf导出pb文件,以及如何使用pb文件

    先罗列出来代码,有时间再解释

    from tensorflow.python.framework import graph_util
    import tensorflow as tf
    
    
    
    def export_model(input_checkpoint, output_graph):
        #这个可以加载saver的模型
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
        graph = tf.get_default_graph() # 获得默认的图
        input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图
    
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            
            saver.restore(sess, input_checkpoint)
            output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=['softmax_linear/softmax_linear','Cast_1'])# 如果有多个输出节点,以逗号隔开这个是重点,输入和输出的参数都需要在这里记录
    
            with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
                f.write(output_graph_def.SerializeToString()) #序列化输出
            
    
    export_model('E:\python\image\code2\model_10\model.ckpt',"E:\python\image\code2\model_10\model.pb")
    

    使用的代码

    import os
    import numpy as np
    import tensorflow as tf
    import model_new
    from PIL import Image
    import matplotlib.pyplot as plt
    import csv
    import shutil
    from tensorflow.python.platform import gfile
    
    def get_one_image(img_dir):
            
            image = Image.open(img_dir)
            
            image = image.resize((128,128))
            image = np.array(image)
    
            return image, img_dir
    
    
    def test_model(model_path, img_path):
        image_array,img_dir = get_one_image( img_path)
        image = tf.cast(image_array,tf.float32)
        #image = tf.image.per_image_standardization(image)
        image = tf.reshape(image,[1,128,128,3])
    
        with tf.Session() as sess:
           
            with gfile.FastGFile(model_path,'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                sess.graph.as_default()
                tf.import_graph_def(graph_def,name='')
            sess.run(tf.global_variables_initializer())
            
            input_x = sess.graph.get_tensor_by_name('Cast_1:0')
            out = sess.graph.get_tensor_by_name('softmax_linear/softmax_linear:0')
            ret = sess.run(out,  feed_dict={input_x: image.eval()})
            print(ret)
            
    
    
    
    out_pb_path="E:\python\image\code2\model_10\frozen_model.pb"
    img_path = "E:\python\image\code\images\0\mmexport1540880139708.jpg"
    test_model(out_pb_path,img_path)
    
  • 相关阅读:
    python基础十一之装饰器进阶
    python基础十之装饰器
    python基础九之函数
    python基础八之文件操作
    python基础七之copy
    python基础七之集合
    python基础数据类型汇总
    python基础六之编码
    synchronized关键字的内存语义
    对于this和当前线程的一些理解
  • 原文地址:https://www.cnblogs.com/bbird/p/9951943.html
Copyright © 2011-2022 走看看