zoukankan      html  css  js  c++  java
  • tensorflow保存读取-【老鱼学tensorflow】

    当我们对模型进行了训练后,就需要把模型保存起来,便于在预测时直接用已经训练好的模型进行预测。

    保存模型的权重和偏置值

    假设我们已经训练好了模型,其中有关于weights和biases的值,例如:

    import tensorflow as tf
    # 保存到文件
    W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
    b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
    

    然后我们初始化这些变量的值,假装是训练后被设置上的值:

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    

    最后进行保存:

    # 创建saver
    saver = tf.train.Saver()
    save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
    print("保存的路径为:", save_path)
    

    这样在打印出:

    保存的路径为: D:/todel/python/saver/save_net.ckpt
    

    在那个目录下,我们看到:

    这样,这些训练后的参数就被保存起来了。

    完整的保存参数的代码为:

    import tensorflow as tf
    # 保存到文件
    W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
    b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
    
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    
    # 创建saver
    saver = tf.train.Saver()
    save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
    print("保存的路径为:", save_path)
    
    

    恢复模型的权重和偏置值

    在我们训练好模型并把训练后的权重和偏置值保存了之后,当我们需要进行预测时,只要读取这个已经保存好的权重和偏置值就可以进行预测了。
    当然,这里的模型结构还是需要进行创建的,因为我们保存的仅仅是权重值和偏置值。

    首先定义要恢复的权重和偏置值的结构:

    import tensorflow as tf
    import numpy as np
    # 定义权重和偏置值的结构,但其中的数值随便填
    W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
    b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
    
    

    注意:其中的name要跟之前保存时一致。

    然后进行加载:

    saver = tf.train.Saver()
    sess = tf.Session()
    # 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
    saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
    
    

    这样输出为:

    weights: [[ 1.  2.  3.]
     [ 3.  4.  5.]]
    biases: [[ 1.  2.  3.]]
    

    就是前面我们保存的内容被恢复出来了。

    完整的恢复代码为:

    import tensorflow as tf
    import numpy as np
    # 定义权重和偏置值的结构,但其中的数值随便填
    W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
    b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
    
    saver = tf.train.Saver()
    sess = tf.Session()
    # 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
    saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
    
    
  • 相关阅读:
    误差可视化小结
    快速排序算法
    解决堆损坏的一点心得
    合并两个有序数组
    nginx安装
    Spark官方3 ---------Spark Streaming编程指南(1.5.0)
    【译】Yarn上常驻Spark-Streaming程序调优
    【Kafka】操作命令
    【Kafka】
    Spark组件
  • 原文地址:https://www.cnblogs.com/dreampursuer/p/8052541.html
Copyright © 2011-2022 走看看