zoukankan      html  css  js  c++  java
  • 两种Tensorflow模型保存的方法

    在Tensorflow中,有两种保存模型的方法:一种是Checkpoint,另一种是Protobuf,也就是PB格式;

    一. Checkpoint方法:

       1.保存时使用方法:

                      tf.train.Saver()

                 生成四个文件:

           checkpoint                 检查点文件

                       model.ckpt.data-xxx 参数值

                       model.ckpt.index   各个参数

                       model.ckpt.meta   图的结构

       2.恢复时使用方法:

         saver.restore() :模型文件依赖Tensorflow,只能在其框架下使用,恢复模型之前需要定义下网络结构

                      saver=tf.train.import_meta_graph('./ckpt/mode..ckpt.meta')  :直接加载网络结构,不需要重新定义网络

    二. PB方法:

      1. 保存模型为PB文件(谷歌推荐),具有语言独立性,可独立运行,序列化的格式,任何语言可解析它,允许其他语言和框架读取,训练和迁移;模型变量是固定的,模型大小会大大减少,适合在手机端运行;

      2. 实现创建模型与使用模型的解耦,使得前向推导Inference代码统一;

      3. PB文件表示MetaGraph的protocol buffer格式的文件;

      4. GraphDef 不保存任何Variable信息,不能从graph_def 来构建图并恢复训练. 

           一般情况下,PB可直接生成;

           当然也可以从checkpoint文件中生成,代码如下:

     1 output_graph = os.path.join('./checkpoint/','frozen_graph.pb')
     2 input_checkpoint = os.path.join('./checkpoint/','model.ckpt-xxxxx')  #[xxxxxx为训练生成的step号]
     3 saver = tf.train.import_meta_graph(input_checkpoint+'.meta',clear_devices=True)
     4 graph = tf.get_default_graph()
     5 input_graph_def = graph.as_graph_def
     6 
     7 for op in graph.get_operations():
     8     print("checkpoint2pb",op.name,op.values())
     9 
    10 variable_names = [v.name for v in tf.trainable_variables()]
    11 pirnt("trainalbe_variables:",variable_names)
    12 
    13 output_node_name=['fc2/add']  #fc2/add 上面的列表里需要存在该操作
    14 
    15 with tf.Session() as sess:
    16     saver.restore(sess,input_checkpoint)
    17 
    18     output_graph_def = graph_util.convert_variables_to_constants(sess=sess,
    19                     input_graph_def = input_graph_def,
    20                     output_node_names = output_node_name)
    21 
    22     with tf.gfile.GFile(output_graph,"wb") as f:
    23         f.write(output_graph_def.SerializeToString())
    24 
    25 
    26     
    27  
    View Code

       

  • 相关阅读:
    导包路径
    django导入环境变量 Please specify Django project root directory
    替换django的user模型,mysql迁移表报错 django.db.migrations.exceptions.InconsistentMigrationHistory: Migration admin.0001_initial is applied before its dependen cy user.0001_initial on database 'default'.
    解决Chrome调试(debugger)
    check the manual that corresponds to your MySQL server version for the right syntax to use near 'order) values ('徐小波','XuXiaoB','男','1',' at line 1")
    MySQL命令(其三)
    MySQL操作命令(其二)
    MySQL命令(其一)
    [POJ2559]Largest Rectangle in a Histogram (栈)
    [HDU4864]Task (贪心)
  • 原文地址:https://www.cnblogs.com/jimchen1218/p/11696419.html
Copyright © 2011-2022 走看看