zoukankan      html  css  js  c++  java
  • 『TensorFlow』第七弹_保存&载入会话_霸王回马

    首更:

    由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe...好吧,caffe也没有这种python风格的设定...

    废话少说,导入包:

    1 import numpy as np
    2 import tensorflow as tf

    保存会话:

    1 W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32)
    2 b = tf.Variable([[1,2,3]],dtype=tf.float32)
    3 
    4 init = tf.global_variables_initializer()
    5 saver = tf.train.Saver() # <---------
    6 
    7 with tf.Session() as sess:
    8     sess.run(init)
    9     save_path = saver.save(sess,'./my_net/saver_net.ckpt') # <---------

    载入会话:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32)
    2 b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess: 
    7     saver.restore(sess,'./my_net/saver_net.ckpt') # <---------
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(b))

     输出如下:

    Weight:
     [[ 1.  2.  3.]
     [ 4.  5.  6.]]
    biases:
     [[ 1.  2.  3.]]
    

     载入会话会加载之前保存的变量,所以不需要tf.global_variables_initializer()激活本次变量了。

     再更:

    引入节点名称后,只要tf变量节点的名称一致,python变量名不一致也能完美继承,也就是说tf变量节点的名称识别权限大于python变量名

    详细的命名规则下节有介绍:『TensorFlow』第八弹_变量与命名空间_固有结界

    保存模型:

    1 W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name='W') # <------
    2 b = tf.Variable([[1,2,3]],dtype=tf.float32,name='b')         # <------
    3 
    4 init = tf.global_variables_initializer()
    5 saver = tf.train.Saver()
    6 
    7 with tf.Session() as sess:
    8     sess.run(init)
    9     save_path = saver.save(sess,'./my_net/saver_net.ckpt')

    W--’W‘,b--’b‘

    载入模型:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32') # <------
    2 b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32') # <------
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess:
    7     saver.restore(sess,'./my_net/saver_net.ckpt')
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(b))

    W,b

    结果报错

    载入模型:

    1 W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name='W') # <------
    2 a = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name='b') # <------
    3 
    4 saver = tf.train.Saver()
    5 
    6 with tf.Session() as sess:
    7     saver.restore(sess,'./my_net/saver_net.ckpt')
    8     print('Weight:
    ',sess.run(W))
    9     print('biases:
    ',sess.run(a))

    W-’W‘,a--’b'

    1 INFO:tensorflow:Restoring parameters from ./my_net/saver_net.ckpt
    2 Weight:
    3  [[ 1.  2.  3.]
    4  [ 4.  5.  6.]]
    5 biases:
    6  [[ 1.  2.  3.]]
  • 相关阅读:
    python中的归并排序
    使用委派取代继承
    ffmpeg 2.3版本号, 关于ffplay音视频同步的分析
    Java学习之路(转)
    怎样通过terminal得到AWS EC2 instance的ip
    让面试官对你“一见钟情”
    Html5 中获取镜像图像
    架构师速成6.3-设计开发思路
    【转】Android的权限permission
    【转】 Android 基于google Zxing实现对手机中的二维码进行扫描--不错
  • 原文地址:https://www.cnblogs.com/hellcat/p/6899683.html
Copyright © 2011-2022 走看看