zoukankan      html  css  js  c++  java
  • 2.1TF模型持久化

    目前tf只能保存模型中的variable变量,整个模型还不能保存,版本1.x

    保存模型代码

    import tensorflow as tf
    import numpy as np
    
    # Save to file
    # remember to define the same dtype and shape when restore
    v1 = tf.Variable(tf.constant(1.0,shape=[1]),  name='v1')
    v2 = tf.Variable(tf.constant(2.0,shape=[1]),  name='v2')
    result=v1+v2
    
    # tf.initialize_all_variables() no long valid from
    # 2017-03-02 if using tensorflow >= 0.12
    if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
        init = tf.initialize_all_variables()
    else:
        init = tf.global_variables_initializer()
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
       sess.run(init)
       save_path = saver.save(sess,"save_model/save_pp.ckpt")
       print("Save to path: ", save_path)

    文件结构如下

    还原模型代码

    ################################################
    # restore variables
    # redefine the same shape and same type for your variables
    v1 = tf.Variable(tf.constant(1.0,shape=[1]),  name='v1')
    v2 = tf.Variable(tf.constant(2.0,shape=[1]),  name='v2')
    result=v1+v2
    # not need init step
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "./save_model/save_pp.ckpt")
        print("v:", sess.run(v1))
        print("result:", sess.run(result))

    报错信息

    未解决

  • 相关阅读:
    大小端判断
    引用计数
    STL_ALGORITHM_H
    书单一览
    GIT版本控制系统(二)
    JS随机数生成算法
    STL学习笔记--临时对象的产生与运用
    乱序优化与GCC的bug
    你的灯亮着吗?
    交换机和路由器
  • 原文地址:https://www.cnblogs.com/jackchen-Net/p/8119706.html
Copyright © 2011-2022 走看看