zoukankan      html  css  js  c++  java
  • 2.1TF模型持久化

    目前tf只能保存模型中的variable变量,整个模型还不能保存,版本1.x

    保存模型代码

    import tensorflow as tf
    import numpy as np
    
    # Save to file
    # remember to define the same dtype and shape when restore
    v1 = tf.Variable(tf.constant(1.0,shape=[1]),  name='v1')
    v2 = tf.Variable(tf.constant(2.0,shape=[1]),  name='v2')
    result=v1+v2
    
    # tf.initialize_all_variables() no long valid from
    # 2017-03-02 if using tensorflow >= 0.12
    if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
        init = tf.initialize_all_variables()
    else:
        init = tf.global_variables_initializer()
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
       sess.run(init)
       save_path = saver.save(sess,"save_model/save_pp.ckpt")
       print("Save to path: ", save_path)

    文件结构如下

    还原模型代码

    ################################################
    # restore variables
    # redefine the same shape and same type for your variables
    v1 = tf.Variable(tf.constant(1.0,shape=[1]),  name='v1')
    v2 = tf.Variable(tf.constant(2.0,shape=[1]),  name='v2')
    result=v1+v2
    # not need init step
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "./save_model/save_pp.ckpt")
        print("v:", sess.run(v1))
        print("result:", sess.run(result))

    报错信息

    未解决

  • 相关阅读:
    get与post的区别
    shell脚本之变量替换
    Oracle sql性能优化
    HTTP协议报头
    Oracle查看表空间和删除表空间
    shell脚本之cat和wc命令
    java设计模式之单例模式
    Wireshark基本介绍和学习TCP三次握手转
    wrong number of arguments (1 for 2)
    PHP生成.url文件 网站常用的保存到桌面功能
  • 原文地址:https://www.cnblogs.com/jackchen-Net/p/8119706.html
Copyright © 2011-2022 走看看