zoukankan      html  css  js  c++  java
  • tf 模型保存

    tf用 tf.train.Saver类来实现神经网络模型的保存和读取。无论保存还是读取,都首先要创建saver对象。

    用saver对象的save方法保存模型

    保存的是所有变量

    save(
        sess,
        save_path,
        global_step=None,  
        latest_filename=None,
        meta_graph_suffix='meta',
        write_meta_graph=True,
        write_state=True
    )

    保存模型需要session,初始化变量

    用法示例

    import tensorflow as tf
    
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
    result = v1 + v2
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, "Model/model.ckpt", global_step=3)

    输出

    1. global_step 放在文件名后面,起个标记作用

    2. save方法输出4个文件

      // checkpoint 里面是一堆路径,model_checkpoint_path 记录了最新模型的路径,all_model_checkpoint_paths 记录了之前模型的路径

      // model.ckpt-3.data-00000-of-00001 存放的是模型参数

      // model.ckpt-3.meta 存放的是计算图

    3. 最多只能保存近5次模型,比如我们迭代100次,每次保存一下,最后只留下了最近的5次。

    用saver对象的restore方法加载模型

    加载的是所有变量,以name为准,假如保存的模型中有变量叫 a ,value是2,那么在加载后,即使重新建立变量a,并赋其他value,其value仍然是2

    restore(
        sess,
        save_path
    )

    加载模型需要session,不需要初始化变量

    用法示例(接前例)

    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
    v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
    # v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22")           # Key v22 not found in checkpoint
    result = v1 + v2
    
    saver = tf.train.Saver()
    #
    with tf.Session() as sess:
        saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
        print(sess.run(result)) # [ 3.]

    1. 重新给 name为 v2的变量 赋值,其结果仍然是3,说明加载了之前的v2

    2. 新建name为 v22 的变量,报错, 在保存的模型中没找到v2 。说明寻找变量以name为准,不以变量名为准

    继续做如下尝试

    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
    # v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
    v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22")           # Key v22 not found in checkpoint
    result = v1 + v3
    
    saver = tf.train.Saver()
    #
    with tf.Session() as sess:
        # sess.run(tf.global_variables_initializer())                     # Key v22 not found in checkpoint
        saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
        # sess.run(tf.global_variables_initializer())                       # Key v22 not found in checkpoint
        print(sess.run(result)) # [ 3.]

    1. 新建name为v22的变量v3,仍然报错,说明新的变量没有被接受

    2. 在加载模型前初始化v3,仍然报错,加载模型后初始化v3,仍然报错,这说明在加载的模型中不接受新的变量。

    继续尝试

    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
    # v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
    v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22")           # Key v22 not found in checkpoint
    result = v1 + v3
    
    saver = tf.train.Saver()
    #
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())                     # Key v22 not found in checkpoint
        print(sess.run(v3))                                             # [7.]
        saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
        sess.run(tf.global_variables_initializer())                       # Key v22 not found in checkpoint
        print(sess.run(result)) # [ 3.]

    在加载模型前初始化变量,正确输出,但在加载后,报错,证实了我上面的说法,“不接受新的变量”

    总结:

    1. 模型加载加载的是所有变量,以name为准

    2. 模型加载后不接受任何新的变量

    3. 在加载模型时需要重新定义计算图上的所有节点,但是变量无需初始化

    加载计算图

    直接加载计算图就无需重新定义计算图上的节点

    用法示例

    saver = tf.train.import_meta_graph("Model/model.ckpt-3.meta")
    
    with tf.Session() as sess:
        saver.restore(sess, "./Model/model.ckpt-3") # 注意路径写法
        print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))     # [3.]
        # print(sess.run(sess.graph.get_tensor_by_name('add:0')))                 # [3.]

    重命名变量

    在加载模型时不接受新的变量,这会造成很多麻烦。

    为解决这个问题,加载模型时可以给变量重命名。

    用法示例

    u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
    u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
    result = u1 + u2
    
    # 若直接声明Saver类对象,会报错变量找不到
    # 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}
    # 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中
    saver = tf.train.Saver({"v1": u1, "v2": u2})
    
    with tf.Session() as sess:
        saver.restore(sess, "./Model/model.ckpt-3")
        print(sess.run(result)) # [ 3.]

    注意重命名格式  老变量的name: 新变量名

    参考资料:

    https://blog.csdn.net/marsjhao/article/details/72829635

    https://blog.csdn.net/shuzfan/article/details/79197432

  • 相关阅读:
    习题三 答案
    习题二 答案
    Python开发【第三篇】:Python基本数据类型
    习题四 答案
    第一个python程序-判断登陆用户名和密码是否正确
    BFPRT算法 查找第k小的数
    设计模式----单例模式
    设计模式----原型模式
    非本地跳转
    链接器如何使用静态库解析引用
  • 原文地址:https://www.cnblogs.com/yanshw/p/10571715.html
Copyright © 2011-2022 走看看