zoukankan      html  css  js  c++  java
  • Tensorflow Learning1 模型的保存和恢复


    CKPT->pb

    Demo

    解析

    tensor name 和 node name 的区别

    Pb 的恢复



    CKPT->pb

    tensorflow的模型保存有两种形式:

    1. ckpt:可以恢复图和变量,继续做训练

    2. pb : 将图序列化,变量成为固定的值,,只可以做inference;不能继续训练


    Demo


      1 def freeze_graph(input_checkpoint,output_graph):
      2 
      3     '''
      4     :param input_checkpoint:
      5     :param output_graph: PB模型保存路径
      6     :return
      7       void
      8     '''
      9 
     10     # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
     11     # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
     12 
     13     # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
     14     output_node_names = "InceptionV3/Logits/SpatialSqueeze" # 如果是多个输出节点,使用 ‘,’号隔开
     15 
     16     ############################     Step1: 从ckpt中恢复图:     #############################################
     17     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
     18     graph = tf.get_default_graph() # 获得默认的图, 可以省略
     19     input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图,可以省略
     20 
     21     with tf.Session() as sess: # 会使用默认的图 作为当前的图
     22         saver.restore(sess, input_checkpoint) #恢复图并得到数据
     23 
     24         ########################     Step2: 创建持久化对象,指定sess,图、以及输出的序列化节点信息    ##############
     25         output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
     26             sess=sess,
     27             input_graph_def=input_graph_def,# 等于:sess.graph_def
     28             output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
     29         #########################    Step3: 模型持久化   #######################################################
     30         with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
     31             f.write(output_graph_def.SerializeToString()) #序列化输出
     32         print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
     33         # for op in graph.get_operations():
     34 
     35         #     print(op.name, op.values())
     36 
     37 
     38 ########################### 调用方式 ################################
     39 # 输入ckpt模型路径
     40 input_checkpoint='models/model.ckpt-10000'
     41 # 输出pb模型的路径
     42 out_pb_path="models/pb/frozen_model.pb"
     43 # 调用freeze_graph将ckpt转为pb
     44 freeze_graph(input_checkpoint,out_pb_path)

    解析

    函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。

    freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

    在保存pb的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称;

    tensor name 和 node name 的区别

    node name 是 图 的节点,里面包含了很多操作和tensor

    tensor 是 node 里面的一个组成部分;

    以input 为例,“input:0”是张量的名称,而"input"表示的是节点的名称

    PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"input:0"


  • 相关阅读:
    吴裕雄 Bootstrap 前端框架开发——Bootstrap 字体图标(Glyphicons)
    Logical partitioning and virtualization in a heterogeneous architecture
    十条实用的jQuery代码片段
    十条实用的jQuery代码片段
    十条实用的jQuery代码片段
    C#比较dynamic和Dictionary性能
    C#比较dynamic和Dictionary性能
    C#比较dynamic和Dictionary性能
    分别使用 XHR、jQuery 和 Fetch 实现 AJAX
    分别使用 XHR、jQuery 和 Fetch 实现 AJAX
  • 原文地址:https://www.cnblogs.com/greentomlee/p/11494383.html
Copyright © 2011-2022 走看看