zoukankan      html  css  js  c++  java
  • TensorFlow学习笔记(8)--网络模型的保存和读取【转】

    转自:http://blog.csdn.net/lwplwf/article/details/62419087

    之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西。

    TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。


    下面代码给出了保存TensorFlow模型的方法:

    import tensorflow as tf
    
    # 声明两个变量
    v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
    v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
    init_op = tf.global_variables_initializer() # 初始化全部变量
    saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
    with tf.Session() as sess:
        sess.run(init_op)
        print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
        print("v2:", sess.run(v2))
        saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件
        print("Model saved in file:", saver_path)
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    这段代码中,通过saver.save函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt",也就是保存到了当前程序所在文件夹里面的save文件夹中。

    TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中实际会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

    • model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
    • model.ckpt文件保存了TensorFlow程序中每一个变量的取值
    • checkpoint文件保存了一个目录下所有的模型文件列表

    这里写图片描述


    下面代码给出了加载TensorFlow模型的方法:

    可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?

    import tensorflow as tf
    
    # 使用和保存模型代码中一样的方式来声明变量
    v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
    v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
    saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
    with tf.Session() as sess:
        saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来
        print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
        print("v2:", sess.run(v2))
        print("Model Restored")
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    运行结果:

    v1: [[ 0.76705766  1.82217288]]
    v2: [[-0.98012197  1.2369734   0.5797025 ]
     [ 2.50458145  0.81897354  0.07858191]]
    Model Restored
    • 1
    • 2
    • 3
    • 4
    • 1
    • 2
    • 3
    • 4

    这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。 
    也就是说使用TensorFlow完成了一次模型的保存和读取的操作。



    如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:

    import tensorflow as tf
    # 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量
    # 直接加载持久化的图
    saver = tf.train.import_meta_graph("save/model.ckpt.meta")
    with tf.Session() as sess:
        saver.restore(sess, "save/model.ckpt")
        # 通过张量的名称来获取张量
        print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    运行程序,输出:

    [[ 0.76705766  1.82217288]]
    • 1
    • 1

    有时可能只需要保存或者加载部分变量。 
    比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。

    为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。

    …未完待续

  • 相关阅读:
    css文字和背景色渐变色
    雪碧图定位
    js操作链接url
    93服务器上获取json数据
    this的区别
    绩效项目总结
    【ASP.NET MVC 学习笔记】- 05 依赖注入工具Ninject
    【ASP.NET MVC 学习笔记】- 04 依赖注入(DI)
    【ASP.NET MVC 学习笔记】- 03 Razor语法
    【ASP.NET MVC 学习笔记】- 02 Attribute
  • 原文地址:https://www.cnblogs.com/seaspring/p/6766923.html
Copyright © 2011-2022 走看看