当我们对模型进行了训练后,就需要把模型保存起来,便于在预测时直接用已经训练好的模型进行预测。
保存模型的权重和偏置值
假设我们已经训练好了模型,其中有关于weights和biases的值,例如:
import tensorflow as tf
# 保存到文件
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
然后我们初始化这些变量的值,假装是训练后被设置上的值:
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
最后进行保存:
# 创建saver
saver = tf.train.Saver()
save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
print("保存的路径为:", save_path)
这样在打印出:
保存的路径为: D:/todel/python/saver/save_net.ckpt
在那个目录下,我们看到:
这样,这些训练后的参数就被保存起来了。
完整的保存参数的代码为:
import tensorflow as tf
# 保存到文件
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 创建saver
saver = tf.train.Saver()
save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
print("保存的路径为:", save_path)
恢复模型的权重和偏置值
在我们训练好模型并把训练后的权重和偏置值保存了之后,当我们需要进行预测时,只要读取这个已经保存好的权重和偏置值就可以进行预测了。
当然,这里的模型结构还是需要进行创建的,因为我们保存的仅仅是权重值和偏置值。
首先定义要恢复的权重和偏置值的结构:
import tensorflow as tf
import numpy as np
# 定义权重和偏置值的结构,但其中的数值随便填
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
注意:其中的name要跟之前保存时一致。
然后进行加载:
saver = tf.train.Saver()
sess = tf.Session()
# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))
这样输出为:
weights: [[ 1. 2. 3.]
[ 3. 4. 5.]]
biases: [[ 1. 2. 3.]]
就是前面我们保存的内容被恢复出来了。
完整的恢复代码为:
import tensorflow as tf
import numpy as np
# 定义权重和偏置值的结构,但其中的数值随便填
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
saver = tf.train.Saver()
sess = tf.Session()
# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))