zoukankan      html  css  js  c++  java
  • tensorflow学习之Saver保存读取

      目前不是很懂。。但主要意思是tf可以把一开始定义的参数,包括Weights和Biases保存到本地,然后再定义一个变量框架去加载(restore)这个参数,作为变量本身的参数进行后续的训练,具体如下:

      

    import numpy as np
    #Save to file
     W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name='weights')
     b = tf.Variable([[1,2,3]],dtype=tf.float32,name='biases')
    
     init= tf.global_variables_initializer()
    
     saver = tf.train.Saver()
    
     with tf.Session() as sess:
         sess.run(init)
         save_path = saver.save(sess,"my_net/save_net.ckpt")
         print("Save to path:", save_path)

    和代码同一目录下就出现了my_net这个文件夹,同时里面有了四个文件

    然后,开始restore该参数

    # restore variables
    #redefine the same shape and same type for your variables
    tf.reset_default_graph()
    W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
    b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases") 
    
    #not need init step
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess,"my_net/save_net.ckpt")
        print("weights:", sess.run(W))
        print("biases:", sess.run(b))


    #
    INFO:tensorflow:Restoring parameters from my_net/save_net.ckpt
    weights: [[1. 2. 3.]
     [3. 4. 5.]]
    biases: [[1. 2. 3.]]

    可以看到把原来的weights和biases都加载了

    人生苦短,何不用python
  • 相关阅读:
    使用gunicorn部署flask项目
    加密算法详解
    elasticsearch安装
    elk下载链接
    mysql允许远程连接
    工作流源代码分析
    查看账户的访问token
    Kube-proxy组件
    创建服务账户/查询访问token
    K8s概念2
  • 原文地址:https://www.cnblogs.com/yqpy/p/11042034.html
Copyright © 2011-2022 走看看