zoukankan      html  css  js  c++  java
  • 『TensorFlow』第七弹_保存&载入会话_霸王回马

    首更:

    由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe...好吧,caffe也没有这种python风格的设定...

    废话少说,导入包:

    1 import numpy as np
    2 import tensorflow as tf

    保存会话:

    1 W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32)
    2 b = tf.Variable([[1,2,3]],dtype=tf.float32)
    3 
    4 init = tf.global_variables_initializer()
    5 saver = tf.train.Saver() # <---------
    6 
    7 with tf.Session() as sess:
    8     sess.run(init)
    9     save_path = saver.save(sess,'./my_net/saver_net.ckpt') # <---------

    载入会话:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32)
    2 b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess: 
    7     saver.restore(sess,'./my_net/saver_net.ckpt') # <---------
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(b))

     输出如下:

    Weight:
     [[ 1.  2.  3.]
     [ 4.  5.  6.]]
    biases:
     [[ 1.  2.  3.]]
    

     载入会话会加载之前保存的变量,所以不需要tf.global_variables_initializer()激活本次变量了。

     再更:

    引入节点名称后,只要tf变量节点的名称一致,python变量名不一致也能完美继承,也就是说tf变量节点的名称识别权限大于python变量名

    详细的命名规则下节有介绍:『TensorFlow』第八弹_变量与命名空间_固有结界

    保存模型:

    1 W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name='W') # <------
    2 b = tf.Variable([[1,2,3]],dtype=tf.float32,name='b')         # <------
    3 
    4 init = tf.global_variables_initializer()
    5 saver = tf.train.Saver()
    6 
    7 with tf.Session() as sess:
    8     sess.run(init)
    9     save_path = saver.save(sess,'./my_net/saver_net.ckpt')

    W--’W‘,b--’b‘

    载入模型:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32') # <------
    2 b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32') # <------
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess:
    7     saver.restore(sess,'./my_net/saver_net.ckpt')
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(b))

    W,b

    结果报错

    载入模型:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name='W') # <------
    2 a = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name='b') # <------
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess:
    7     saver.restore(sess,'./my_net/saver_net.ckpt')
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(a))

    W-’W‘,a--’b'

    1 INFO:tensorflow:Restoring parameters from ./my_net/saver_net.ckpt
    2 Weight:
    3  [[ 1.  2.  3.]
    4  [ 4.  5.  6.]]
    5 biases:
    6  [[ 1.  2.  3.]]
  • 相关阅读:
    Unique Binary Search Trees——LeetCode
    Binary Tree Inorder Traversal ——LeetCode
    Maximum Product Subarray——LeetCode
    Remove Linked List Elements——LeetCode
    Maximum Subarray——LeetCode
    Validate Binary Search Tree——LeetCode
    Swap Nodes in Pairs——LeetCode
    Find Minimum in Rotated Sorted Array——LeetCode
    Linked List Cycle——LeetCode
    VR AR MR
  • 原文地址:https://www.cnblogs.com/hellcat/p/6899683.html
Copyright © 2011-2022 走看看