zoukankan      html  css  js  c++  java
  • 吴裕雄 python 神经网络——TensorFlow 滑动平均类的保存

    import tensorflow as tf
    
    v = tf.Variable(0, dtype=tf.float32, name="v")
    for variables in tf.global_variables(): 
        print(variables.name)
        
    ema = tf.train.ExponentialMovingAverage(0.99)
    maintain_averages_op = ema.apply(tf.global_variables())
    for variables in tf.global_variables(): 
        print(variables.name)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        sess.run(tf.assign(v, 10))
        sess.run(maintain_averages_op)
        # 保存的时候会将v:0  v/ExponentialMovingAverage:0这两个变量都存下来。
        saver.save(sess, "E:\Saved_model\model2.ckpt")
        print(sess.run([v, ema.average(v)]))

    v = tf.Variable(0, dtype=tf.float32, name="v")
    
    # 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
    saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
    with tf.Session() as sess:
        saver.restore(sess, "E:\Saved_model\model2.ckpt")
        print sess.run(v)

  • 相关阅读:
    cf Round 633
    Django学习手册
    Django学习手册
    Django学习手册
    Django学习手册
    Django学习手册
    Django学习手册
    ERROR CL .exe……错误
    DLL、lib等链接库文件的使用
    HTTP协议
  • 原文地址:https://www.cnblogs.com/tszr/p/10875117.html
Copyright © 2011-2022 走看看