zoukankan      html  css  js  c++  java
  • 5.1 Tensorflow:图与模型的加载与存储

    前言

    自己学Tensorflow,现在看的书是《TensorFlow技术解析与实战》,不得不说这书前面的部分有点坑,后面的还不清楚.图与模型的加载写的不清楚,书上的代码还不能运行=- =,真是BI….咳咳.之后还是开始了查文档,翻博客的填坑之旅
    ,以下为学习总结.

    快速应用

    存储与加载,简单示例

    # 一般而言我们是构建模型之后,session运行,但是这次不同之处在于我们是构件好之后存储了模型
    # 然后在session中加载存储好的模型,再运行
    import tensorflow as tf
    import os
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    # 声明两个变量
    v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
    v2 = tf.Variable(tf.random_normal([2, 3]), name='v2')
    init_op = tf.global_variables_initializer() # 初始化全部变量
    # saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 声明tf.train.Saver类用于保存模型
    saver = tf.train.Saver()
    # 只存储图
    if not os.path.exists('save/model.meta'):
        saver.export_meta_graph('save/model.meta')
    
    
    print()
    with tf.Session() as sess:
        sess.run(init_op)
        print('v1:', sess.run(v1)) # 打印v1、v2的值一会读取之后对比
        print('v2:', sess.run(v2))
        saver_path = saver.save(sess, 'save/model.ckpt')  # 将模型保存到save/model.ckpt文件
        print('Model saved in file:', saver_path)
    
    print()
    with tf.Session() as sess:
        saver.restore(sess, 'save/model.ckpt') # 即将固化到硬盘中的模型从保存路径再读取出来,这样就可以直接使用之前训练好,或者训练到某一阶段的的模型了
        print('v1:', sess.run(v1)) # 打印v1、v2的值和之前的进行对比
        print('v2:', sess.run(v2))
        print('Model Restored')
    
    print()
    # 只加载图,
    saver = tf.train.import_meta_graph('save/model.ckpt.meta')
    with tf.Session() as sess:
        saver.restore(sess, 'save/model.ckpt')
        # 通过张量的名称来获取张量,也可以直接运行新的张量
        print('v1:', sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
        print('v2:', sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))

    运行结果:

    
    v1: [[-0.78213912 -0.72646964]]
    v2: [[-0.36301413 -0.99892306  0.21593148]
     [-1.09692276 -0.06931346  0.19474344]]
    Model saved in file: save/model.ckpt
    
    v1: [[-0.78213912 -0.72646964]]
    v2: [[-0.36301413 -0.99892306  0.21593148]
     [-1.09692276 -0.06931346  0.19474344]]
    Model Restored
    
    v1: [[-0.78213912 -0.72646964]]
    v2: [[-0.36301413 -0.99892306  0.21593148]
     [-1.09692276 -0.06931346  0.19474344]]
    

    构建模型后直接运行的结果,与加载存储的模型,加载存储的图,并哪找张量的名称获取张量并运行的结果是一致的

    存储的文件

    保存的文件

    tf.train.Saver与存储文件的讲解

    核心定义

    主要类:tf.train.Saver类负责保存和还原神经网络
    自动保存为三个文件:模型文件列表checkpoint,计算图结构model.ckpt.meta,每个变量的取值model.ckpt。其中前两个自动生成。
    加载持久化图:通过tf.train.import_meta_graph(“save/model.ckpt.meta”)加载持久化的图

    存储文件的讲解

    这段代码中,通过saver.save函数将TensorFlow模型保存到了model/model.ckpt文件中,这里代码中指定路径为”save/model.ckpt”,也就是保存到了当前程序所在文件夹里面的save文件夹中。

    TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

    checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在
    checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState
    Protocol Buffer.

    model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
    TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef
    Protocol Buffer定义的。MetaGraphDef
    中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef
    信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

    model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice
    Protocol
    Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,请自查。

    保存图与模型进阶

    按迭代次数保存

    # 在1000次迭代时存储
    saver.save(sess, 'my_test_model',global_step=1000)

    运行结果:

    my_test_model-1000.index
    my_test_model-1000.meta
    my_test_model-1000.data-00000-of-00001
    checkpoint

    按时间保存

    #saves a model every 2 hours and maximum 4 latest models are saved.
    saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

    更详细的解释

    其实更详细的解释就在源码之中,这些英语还是简单,我相信以大家的水平应该都能看得懂。就不侮辱大家的智商。

      def __init__(self,
                   var_list=None,
                   reshape=False,
                   sharded=False,
                   max_to_keep=5,
                   keep_checkpoint_every_n_hours=10000.0,
                   # 默认时间是一万小时,有趣
                   # 但我们只争朝夕
                   name=None,
                   restore_sequentially=False,
                   saver_def=None,
                   builder=None,
                   defer_build=False,
                   allow_empty=False,
                   write_version=saver_pb2.SaverDef.V2,
                   pad_step_number=False,
                   save_relative_paths=False):
        """Creates a `Saver`.
    
        The constructor adds ops to save and restore variables.
    
        `var_list` specifies the variables that will be saved and restored. It can
        be passed as a `dict` or a list:
    
        * A `dict` of names to variables: The keys are the names that will be
          used to save or restore the variables in the checkpoint files.
        * A list of variables: The variables will be keyed with their op name in
          the checkpoint files.
    
        For example:
    
        ```python
        v1 = tf.Variable(..., name='v1')
        v2 = tf.Variable(..., name='v2')
    
        # Pass the variables as a dict:
        saver = tf.train.Saver({'v1': v1, 'v2': v2})
    
        # Or pass them as a list.
        saver = tf.train.Saver([v1, v2])
        # Passing a list is equivalent to passing a dict with the variable op names
        # as keys:
        saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
        ```
    
        The optional `reshape` argument, if `True`, allows restoring a variable from
        a save file where the variable had a different shape, but the same number
        of elements and type.  This is useful if you have reshaped a variable and
        want to reload it from an older checkpoint.
    
        The optional `sharded` argument, if `True`, instructs the saver to shard
        checkpoints per device.
    
        Args:
          var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
            names to `SaveableObject`s. If `None`, defaults to the list of all
            saveable objects.
          reshape: If `True`, allows restoring parameters from a checkpoint
            where the variables have a different shape.
          sharded: If `True`, shard the checkpoints, one per device.
          max_to_keep: Maximum number of recent checkpoints to keep.
            Defaults to 5.
          keep_checkpoint_every_n_hours: How often to keep checkpoints.
            Defaults to 10,000 hours.
          name: String.  Optional name to use as a prefix when adding operations.
          restore_sequentially: A `Bool`, which if true, causes restore of different
            variables to happen sequentially within each device.  This can lower
            memory usage when restoring very large models.
          saver_def: Optional `SaverDef` proto to use instead of running the
            builder. This is only useful for specialty code that wants to recreate
            a `Saver` object for a previously built `Graph` that had a `Saver`.
            The `saver_def` proto should be the one returned by the
            `as_saver_def()` call of the `Saver` that was created for that `Graph`.
          builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
            Defaults to `BaseSaverBuilder()`.
          defer_build: If `True`, defer adding the save and restore ops to the
            `build()` call. In that case `build()` should be called before
            finalizing the graph or using the saver.
          allow_empty: If `False` (default) raise an error if there are no
            variables in the graph. Otherwise, construct the saver anyway and make
            it a no-op.
          write_version: controls what format to use when saving checkpoints.  It
            also affects certain filepath matching logic.  The V2 format is the
            recommended choice: it is much more optimized than V1 in terms of
            memory required and latency incurred during restore.  Regardless of
            this flag, the Saver is able to restore from both V2 and V1 checkpoints.
          pad_step_number: if True, pads the global step number in the checkpoint
            filepaths to some fixed width (8 by default).  This is turned off by
            default.
          save_relative_paths: If `True`, will write relative paths to the
            checkpoint state file. This is needed if the user wants to copy the
            checkpoint directory and reload from the copied directory.
    
        Raises:
          TypeError: If `var_list` is invalid.
          ValueError: If any of the keys or values in `var_list` are not unique.
        """
  • 相关阅读:
    matlab 函数库
    阿甘的珠宝 大数据博弈综合应用 SG函数 + 最后取为输或赢
    hdu 1536 博弈 SG函数(dfs)
    hdu 1907 John / 2509 Be the Winner 博弈 最后取完者为输
    深入理解 Nim 博弈
    SG函数模板 hdu 1848/1847/1849/1850/1851
    初始博弈 hdu 1846 Brave Game
    乘数密码 扩展欧几里得求逆元
    68.最大k乘积问题 (15分)
    第一次作业
  • 原文地址:https://www.cnblogs.com/fonttian/p/9162799.html
Copyright © 2011-2022 走看看