  • TensorFlow: saving and loading models

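    TensorFlow offers two ways to save a model for later use: a frozen .pb file or a SavedModel directory. A frozen .pb stores the graph in a single file, with every variable converted to a constant. A SavedModel stores the graph, the variable values and named input/output signatures separately. The examples use the TensorFlow 1.x graph/session API.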

    1. pb

    1.1 Saving the model as a .pb file

    frozen_pb.py

    import tensorflow as tf  # TensorFlow 1.x API (graph + session)

    with tf.Session(graph=tf.Graph()) as sess:
        # Toy graph: out_add = in_x * in_y + b
        x = tf.placeholder(tf.int32, name='in_x')
        y = tf.placeholder(tf.int32, name='in_y')
        b = tf.Variable(1, name='b')
        m = tf.multiply(x, y)
        a = tf.add(m, b, name='out_add')

        sess.run(tf.global_variables_initializer())

        # Freeze: replace each variable reachable from 'out_add' with a
        # constant holding its current value; unrelated nodes are dropped
        constant_graph = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['out_add'])

        # Sanity check before saving: 10 * 3 + 1 = 31
        print(sess.run(a, feed_dict={x: 10, y: 3}))

        # Serialize the frozen GraphDef to a single file
        with tf.gfile.GFile('./model.pb', mode='wb') as f:
            f.write(constant_graph.SerializeToString())
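You can list the nodes of the frozen GraphDef to find the names to pass to get_tensor_by_name, or to check that freezing worked. The sketch below rebuilds the same toy graph in memory, freezes it and returns (name, op) pairs. It assumes `tensorflow.compat.v1` is available (TF 1.15 or 2.x). The helper names `build_frozen_graph_def` and `list_nodes` are hypothetical.

```python
import tensorflow.compat.v1 as tf


def build_frozen_graph_def():
    """Hypothetical helper: rebuild out_add = in_x * in_y + b and freeze it."""
    with tf.Session(graph=tf.Graph()) as sess:
        x = tf.placeholder(tf.int32, name='in_x')
        y = tf.placeholder(tf.int32, name='in_y')
        b = tf.Variable(1, name='b')
        tf.add(tf.multiply(x, y), b, name='out_add')
        sess.run(tf.global_variables_initializer())
        # Variables become Const nodes; nodes not needed by 'out_add' are dropped
        return tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['out_add'])


def list_nodes(graph_def):
    """Hypothetical helper: return (node name, op type) for every node."""
    return [(node.name, node.op) for node in graph_def.node]


if __name__ == '__main__':
    for name, op in list_nodes(build_frozen_graph_def()):
        print(name, op)
```

After freezing, `b` should show up as a `Const` node rather than a variable op. That is why the loading code needs no initializer.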

    1.2 Loading the .pb model

    call_pb.py

    import tensorflow as tf  # TensorFlow 1.x API

    # Read the serialized GraphDef from disk
    with tf.gfile.GFile('./model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import it into a fresh graph; name='' keeps the original node names
    # (otherwise every name gets an "import/" prefix)
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')

    with tf.Session(graph=graph) as sess:
        # The variable was frozen into a constant, so no initializer is needed
        # print(sess.run('b:0'))

        # Tensor names have the form "<op name>:<output index>"
        in_x = graph.get_tensor_by_name('in_x:0')
        in_y = graph.get_tensor_by_name('in_y:0')
        out_add = graph.get_tensor_by_name('out_add:0')

        ret = sess.run(out_add, feed_dict={in_x: 8, in_y: 9})
        print(ret)  # 8 * 9 + 1 = 73
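tf.import_graph_def can also return the tensors you ask for through its `return_elements` argument, which saves importing first and then calling get_tensor_by_name. The sketch below assumes `tensorflow.compat.v1`. It rebuilds the graph in memory so it runs without model.pb. `freeze_toy_graph` and `run_frozen` are hypothetical helper names.

```python
import tensorflow.compat.v1 as tf


def freeze_toy_graph():
    """Hypothetical helper: build out_add = in_x * in_y + b, return frozen GraphDef."""
    with tf.Session(graph=tf.Graph()) as sess:
        x = tf.placeholder(tf.int32, name='in_x')
        y = tf.placeholder(tf.int32, name='in_y')
        b = tf.Variable(1, name='b')
        tf.add(tf.multiply(x, y), b, name='out_add')
        sess.run(tf.global_variables_initializer())
        return tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['out_add'])


def run_frozen(graph_def, x_value, y_value):
    """Hypothetical helper: import a frozen GraphDef and evaluate out_add."""
    graph = tf.Graph()
    with graph.as_default():
        # return_elements yields the tensors in the order they are listed
        in_x, in_y, out_add = tf.import_graph_def(
            graph_def, name='',
            return_elements=['in_x:0', 'in_y:0', 'out_add:0'])
    with tf.Session(graph=graph) as sess:
        return sess.run(out_add, feed_dict={in_x: x_value, in_y: y_value})


if __name__ == '__main__':
    print(run_frozen(freeze_toy_graph(), 8, 9))
```

In real use you would pass the GraphDef parsed from model.pb instead of calling `freeze_toy_graph`.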

    2. saved_model

    2.1 Saving the model as a SavedModel

    frozen_sm.py

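    import os
    import tensorflow as tf  # TensorFlow 1.x API

    # simple_save fails if the export directory already exists and is not
    # empty, so delete ./sm before re-running this script
    saved_model_path = os.path.join(os.getcwd(), 'sm')

    with tf.Session(graph=tf.Graph()) as sess:
        x = tf.placeholder(tf.int32, name='in_x')
        y = tf.placeholder(tf.int32, name='in_y')
        b = tf.Variable(1, name='b')
        m = tf.multiply(x, y)
        a = tf.add(m, b, name='out_add')

        sess.run(tf.global_variables_initializer())

        # Writes the graph, the variable values and a 'serving_default'
        # signature mapping the given input/output keys to these tensors
        tf.saved_model.simple_save(sess, saved_model_path,
                                   inputs={'in_x': x, 'in_y': y},
                                   outputs={'out_add': a})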
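The sketch below shows what simple_save writes. It exports the same toy graph into a fresh temporary directory and lists the files. On a typical installation you should see saved_model.pb (graph plus signatures) next to a variables/ directory that holds the checkpoint with the value of b. Unlike the frozen .pb, the variables are not baked into the graph. It assumes `tensorflow.compat.v1` (TF 1.15 or 2.x). `export_toy_model` is a hypothetical helper name.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf


def export_toy_model(export_dir):
    """Hypothetical helper: save out_add = in_x * in_y + b with simple_save."""
    with tf.Session(graph=tf.Graph()) as sess:
        x = tf.placeholder(tf.int32, name='in_x')
        y = tf.placeholder(tf.int32, name='in_y')
        b = tf.Variable(1, name='b')
        a = tf.add(tf.multiply(x, y), b, name='out_add')
        sess.run(tf.global_variables_initializer())
        tf.saved_model.simple_save(sess, export_dir,
                                   inputs={'in_x': x, 'in_y': y},
                                   outputs={'out_add': a})


if __name__ == '__main__':
    # Export into a sub-directory that does not exist yet
    export_dir = os.path.join(tempfile.mkdtemp(), 'sm')
    export_toy_model(export_dir)
    # Print every file path relative to the export directory
    for root, _, files in os.walk(export_dir):
        for name in sorted(files):
            print(os.path.relpath(os.path.join(root, name), export_dir))
```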

    2.2 Loading the SavedModel

    call_sm.py

    import tensorflow as tf  # TensorFlow 1.x API

    with tf.Session(graph=tf.Graph()) as sess:
        # TF 1.x signature: restores the graph and variable values tagged
        # 'serve' into sess.graph (TF 2.x's tf.saved_model.load differs)
        tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], './sm')

        in_x = sess.graph.get_tensor_by_name('in_x:0')
        in_y = sess.graph.get_tensor_by_name('in_y:0')
        out_add = sess.graph.get_tensor_by_name('out_add:0')

        ret = sess.run(out_add, feed_dict={in_x: 8, in_y: 5})
        print(ret)  # 8 * 5 + 1 = 41
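Hard-coding 'in_x:0' works here because the names were chosen at save time. However, the serving signature that simple_save writes already records the tensor behind each input and output key, so a loader can look the names up instead. A minimal sketch, assuming `tensorflow.compat.v1` and the default signature key ('serving_default') that simple_save uses. `export_toy_model` and `load_and_run` are hypothetical helper names.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

SERVING = tf.saved_model.tag_constants.SERVING
SIGNATURE_KEY = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def export_toy_model(export_dir):
    """Hypothetical helper: save out_add = in_x * in_y + b with simple_save."""
    with tf.Session(graph=tf.Graph()) as sess:
        x = tf.placeholder(tf.int32, name='in_x')
        y = tf.placeholder(tf.int32, name='in_y')
        b = tf.Variable(1, name='b')
        a = tf.add(tf.multiply(x, y), b, name='out_add')
        sess.run(tf.global_variables_initializer())
        tf.saved_model.simple_save(sess, export_dir,
                                   inputs={'in_x': x, 'in_y': y},
                                   outputs={'out_add': a})


def load_and_run(export_dir, x_value, y_value):
    """Hypothetical helper: resolve tensors via the serving signature, then run."""
    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph_def = tf.saved_model.loader.load(sess, [SERVING], export_dir)
        signature = meta_graph_def.signature_def[SIGNATURE_KEY]
        # Each TensorInfo in the signature stores the underlying tensor name
        get = sess.graph.get_tensor_by_name
        in_x = get(signature.inputs['in_x'].name)
        in_y = get(signature.inputs['in_y'].name)
        out_add = get(signature.outputs['out_add'].name)
        return sess.run(out_add, feed_dict={in_x: x_value, in_y: y_value})


if __name__ == '__main__':
    export_dir = os.path.join(tempfile.mkdtemp(), 'sm')
    export_toy_model(export_dir)
    print(load_and_run(export_dir, 8, 5))
```

With this approach, only the signature keys ('in_x', 'in_y', 'out_add') are part of the contract between saver and loader. The graph's internal op names could change without breaking the loading code.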
  • Original article: https://www.cnblogs.com/vsignsoft/p/14000250.html