zoukankan      html  css  js  c++  java
  • checkpoint文件

    tensorflow saver和checkpoint总结_gzj_1101的专栏-程序员宅基地_checkpoint 深度学习 - 程序员宅基地 (cxyzjd.com)

    最近在做深度学习相关实验,经常要用到别人的预训练模型,有时候常常不知道怎么使用,因此这篇博客将专门做一个总结。

    1 Tensorflow 模型文件

    checkpoint
    model.ckpt-200.data-00000-of-00001
    model.ckpt-200.index
    model.ckpt-200.meta

    1.1 meta文件

    model.ckpt-200.meta文件保存的是图结构,通俗地讲就是神经网络的网络结构。一般而言网络结构是不会发生改变,所以可以只保存一个就行了。我们可以使用下面的代码只在第一次保存meta文件。

    saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)

    并且还可以使用tf.train.import_meta_graph(‘model.ckpt-200.meta’)能够导入图结构。

    1.2 data文件

    model.ckpt-200.data-00000-of-00001

    数据文件,保存的是网络的权值,偏置,操作等等。

    1.3 index文件

    model.ckpt-200.index是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。

    Note: 以前的版本中tensorflow的model只保存一个文件中。

    2 保存和恢复Tensorflow模型

    2.1 保存模型

    tf.train.Saver 类别提供了保存和恢复模型的方法。tf.train.Saver 构造函数针对图中所有变量或指定列表的变量将 save 和 restore op 添加到图中。Saver 对象提供了运行这些 op 的方法,指定了写入或读取检查点文件的路径。
    一般而言,如果不指定任何参数,tf.train.Saver会保存所有的参数。下面是一个简单的例子,来自TensorFlow官网

    # Create some variables.
    v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
    v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
    
    inc_v1 = v1.assign(v1+1)
    dec_v2 = v2.assign(v2-1)
    
    # Add an op to initialize the variables.
    init_op = tf.global_variables_initializer()
    
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Later, launch the model, initialize the variables, do some work, and save the
    # variables to disk.
    with tf.Session() as sess:
      sess.run(init_op)
      # Do some work with the model.
      inc_v1.op.run()
      dec_v2.op.run()
      # Save the variables to disk.
      save_path = saver.save(sess, "/tmp/model.ckpt")
      print("Model saved in path: %s" % save_path)

    最后会将v1和v2以及op都保存下来。但是如果你只想保存v1和v2,你可以这样写。

    tf.train.Saver(V1.values()+V2.values())

    2.2 恢复模型

    模型加载需要利用Saver.restore方法。可以加载固定参数,也可以加在所有参数。

    saver.restore(sess,model_path)

    具体的不一一介绍了,有兴趣的可以看一下参考的连接。

    2.3 加载模型限制

    pre-trained 模型常用来做迁移

    参考资料

    1.TensorFlow学习笔记(8)–网络模型的保存和读取

    2.Tensorflow加载预训练模型和保存模型

    3.TensorFlow, why there are 3 files after saving the model?

    学习,但是却存在一个限制,那就是网络的前一层必须是一致的,以vgg16为例,如果你利用前面几层提取特征,前面几层的网络必须得和vgg保持一致。而后面的网络参数是随机初始化的。

  • 相关阅读:
    charles安装以及手机端的设置
    ON DUPLICATE KEY UPDATE 用法与说明
    亿级流量架构之网关设计思路、常见网关对比
    灰度发布系统架构设计
    Jmeter 并发测试
    springboot --- Swagger UI初识
    TortoiseGIT 一直提示输入密码的解决方法!
    MySQL 5.6 参数详解
    LVS 轮询调度详解
    MongoDB 权限
  • 原文地址:https://www.cnblogs.com/yibeimingyue/p/15560122.html
Copyright © 2011-2022 走看看