zoukankan      html  css  js  c++  java
  • tensorflow学习之Saver保存读取

      目前不是很懂。。但主要意思是tf可以把一开始定义的参数,包括Weights和Biases保存到本地,然后再定义一个变量框架去加载(restore)这个参数,作为变量本身的参数进行后续的训练,具体如下:

      

    import numpy as np
    #Save to file
     W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name='weights')
     b = tf.Variable([[1,2,3]],dtype=tf.float32,name='biases')
    
     init= tf.global_variables_initializer()
    
     saver = tf.train.Saver()
    
     with tf.Session() as sess:
         sess.run(init)
         save_path = saver.save(sess,"my_net/save_net.ckpt")
         print("Save to path:", save_path)

    和代码同一目录下就出现了my_net这个文件夹,同时里面有了四个文件

    然后,开始restore该参数

    # restore variables
    #redefine the same shape and same type for your variables
    tf.reset_default_graph()
    W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
    b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases") 
    
    #not need init step
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess,"my_net/save_net.ckpt")
        print("weights:", sess.run(W))
        print("biases:", sess.run(b))


    #
    INFO:tensorflow:Restoring parameters from my_net/save_net.ckpt
    weights: [[1. 2. 3.]
     [3. 4. 5.]]
    biases: [[1. 2. 3.]]

    可以看到把原来的weights和biases都加载了

    人生苦短,何不用python
  • 相关阅读:
    MYSQL的FOUND_ROWS()函数
    mysql连表查询
    mysql事务
    js正则表达式
    mysql关键字执行顺序
    spring aop xml中配置实例
    spring注入bean的五种方式
    【CSS】之选择器性能和规范
    【视频】之H.264
    【Javascript】之eval()
  • 原文地址:https://www.cnblogs.com/yqpy/p/11042034.html
Copyright © 2011-2022 走看看