zoukankan      html  css  js  c++  java
  • Saver 保存与读取

    tensorflow 框架下的Saver 功能,用以保存和读取运算数据

    Saver 保存数据

    代码

    import tensorflow as tf
    
    # Save to file
    #remember to define the same dtype and shape when restore
    W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
    b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
    
    
    init = tf.global_variables_initializer()
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
       sess.run(init)
       save_path = saver.save(sess, "my_net/save_net.ckpt")
       print("Save to path: ", save_path)
    

    生成文件截图

    运行代码,会在当前目录下生成一个新文件夹my_net ,文件夹内容有

    解释

    1. 这里我们定义了两个张量,2行3列的W,和1行3列的b。这里强调行列形状 ,原因是只有存储张量的形状和读取时张量形状相同,才能被读取成功。

    2. 并且这里的W和b都定义了name ,name是读取时候对应变量的关键 --'weights'和'biases'。和张量符号W和b没什么关系。

    3. 定义文件扩展名为ckpt ,因为官方是这样定义的。

    Saver 读取数据

    import tensorflow as tf
    
    W = tf.Variable(tf.zeros([2,3]), dtype=tf.float32, name="weights")
    b = tf.Variable(tf.zeros([1,3]), dtype=tf.float32, name="biases")
    
    # not need init step
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "my_net/save_net.ckpt")
        print("weights:", sess.run(W))
        print("biases:", sess.run(b))
    

    打印结果

    可以看到,读取数据文件代码定义张量W和b为全0,经过Saver 读取处理后,张量的数值成为保存文件中的数值。

    解释

    1. 用Saver从文件读取,然后把读到的张量自动赋值给name相同 的张量

    2. 注意在读取代码中,张量被定义,但是没有初始化环节(sess.run(init)这一步) ,因为读取文件中的张量已经被初始化过了,这里就不用了

  • 相关阅读:
    用python执行Linux命令
    ls用法
    frigate_TUNNEL
    Python读写Excel文件的实例
    python操作Excel读写--使用xlrd
    iptables详解
    IPy过滤
    python 类中__call__内置函数的使用
    python 类中__init__函数的使用
    超继承
  • 原文地址:https://www.cnblogs.com/maskerk/p/9984179.html
Copyright © 2011-2022 走看看