  • saved_model: a simple example of saving and loading a model

    https://www.w3xue.com/exp/article/201812/10995.html

    =====1===== Saving the model
    import tensorflow as tf
    from tensorflow import saved_model as sm

    # First, define a trivially simple computation graph
    X = tf.placeholder(tf.float32, shape=(3, ))
    scale = tf.Variable([10, 11, 12], dtype=tf.float32)
    y = tf.multiply(X, scale)

    # Run it in a session
    with tf.Session() as sess:
        sess.run(tf.initializers.global_variables())
        value = sess.run(y, feed_dict={X: [1., 2., 3.]})
        print(value)

        # Prepare to export the model
        path = '/home/×××/tf_model/model_1'
        builder = sm.builder.SavedModelBuilder(path)

        # Build a TensorInfo protobuf for each tensor to be recovered in a new session
        X_TensorInfo = sm.utils.build_tensor_info(X)
        scale_TensorInfo = sm.utils.build_tensor_info(scale)
        y_TensorInfo = sm.utils.build_tensor_info(y)

        # Build the SignatureDef protobuf
        SignatureDef = sm.signature_def_utils.build_signature_def(
            inputs={'input_1': X_TensorInfo, 'input_2': scale_TensorInfo},
            outputs={'output': y_TensorInfo},
            method_name='what'
        )

        # Write the graph, variables, etc. into a MetaGraphDef protobuf.
        # The entries in tags and the keys of signature_def_map can both be arbitrary
        # strings; so that you don't have to remember a custom string elsewhere,
        # TensorFlow provides predefined constants you can use instead.
        builder.add_meta_graph_and_variables(
            sess, tags=[sm.tag_constants.TRAINING],
            signature_def_map={sm.signature_constants.CLASSIFY_INPUTS: SignatureDef}
        )

        # Write the MetaGraphDef to disk
        builder.save()
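    For reference, a successful builder.save() writes the standard SavedModel layout under the export path (directory and file names as produced by TensorFlow 1.x; the number of variables shards may differ):

    ```text
    model_1/
    ├── saved_model.pb                    # serialized MetaGraphDef, including the SignatureDef
    └── variables/
        ├── variables.data-00000-of-00001 # variable values (here: scale)
        └── variables.index
    ```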
    
    

      

    =====2===== Loading the model

    This stores the entire model to disk. Since the three tensors X, scale, and y were all serialized into it, they can be fully recovered when the model is restored:

    import tensorflow as tf
    from tensorflow import saved_model as sm

    # A session object is needed; the model is restored into it
    with tf.Session() as sess:
        path = '/home/×××/tf_model/model_1'
        MetaGraphDef = sm.loader.load(sess, tags=[sm.tag_constants.TRAINING], export_dir=path)

        # Extract the SignatureDef protobuf
        SignatureDef_d = MetaGraphDef.signature_def
        SignatureDef = SignatureDef_d[sm.signature_constants.CLASSIFY_INPUTS]

        # Extract the TensorInfo protobufs of the three tensors
        X_TensorInfo = SignatureDef.inputs['input_1']
        scale_TensorInfo = SignatureDef.inputs['input_2']
        y_TensorInfo = SignatureDef.outputs['output']

        # Resolve the concrete Tensors.
        # The graph argument of .get_tensor_from_tensor_info() may be omitted;
        # TensorFlow then uses the default graph.
        X = sm.utils.get_tensor_from_tensor_info(X_TensorInfo, sess.graph)
        scale = sm.utils.get_tensor_from_tensor_info(scale_TensorInfo, sess.graph)
        y = sm.utils.get_tensor_from_tensor_info(y_TensorInfo, sess.graph)

        print(sess.run(scale))
        print(sess.run(y, feed_dict={X: [3., 2., 1.]}))

    # Output
    [10. 11. 12.]
    [30. 22. 12.]
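    As a sanity check on the numbers above (plain Python, no TensorFlow needed): tf.multiply is element-wise, so the restored graph computes y[i] = X[i] * scale[i]:

    ```python
    # Element-wise product, mirroring tf.multiply on 1-D tensors
    def multiply(xs, scale):
        return [x * s for x, s in zip(xs, scale)]

    scale = [10.0, 11.0, 12.0]
    print(multiply([1.0, 2.0, 3.0], scale))  # the save-time run: [10.0, 22.0, 36.0]
    print(multiply([3.0, 2.0, 1.0], scale))  # the restore-time run: [30.0, 22.0, 12.0]
    ```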

    =====3===== Full walkthrough

    https://github.com/Jerryzhangzhao/DL_tensorflow/blob/master/save%20and%20restore%20model/use%20saved%20model/save_and_restore_by_savedmodelbuilder.py

  • Original article: https://www.cnblogs.com/zhangbojiangfeng/p/11759903.html