tensorflow 框架下的Saver 功能,用以保存和读取运算数据
Saver 保存数据
代码
import tensorflow as tf
# Save to file
#remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess, "my_net/save_net.ckpt")
print("Save to path: ", save_path)
生成文件截图
运行代码,会在当前目录下生成一个新文件夹my_net
,文件夹内容有
解释
-
这里我们定义了两个张量,2行3列的W,和1行3列的b。这里强调行列形状 ,原因是只有存储张量的形状和读取时张量形状相同,才能被读取成功。
-
并且这里的W和b都定义了
name
,name是读取时候对应变量的关键 --'weights'和'biases'。和张量符号W和b没什么关系。 -
定义文件扩展名为
ckpt
,因为官方是这样定义的。
Saver 读取数据
import tensorflow as tf
W = tf.Variable(tf.zeros([2,3]), dtype=tf.float32, name="weights")
b = tf.Variable(tf.zeros([1,3]), dtype=tf.float32, name="biases")
# not need init step
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "my_net/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))
打印结果
可以看到,读取数据文件代码定义张量W和b为全0,经过Saver 读取处理后,张量的数值成为保存文件中的数值。
解释
-
用Saver从文件读取,然后把读到的张量自动赋值给name相同 的张量
-
注意在读取代码中,张量被定义,但是没有初始化环节(sess.run(init)这一步) ,因为读取文件中的张量已经被初始化过了,这里就不用了