zoukankan      html  css  js  c++  java
  • 【转】tf.SessionRunHook使用方法

    原文地址:https://blog.csdn.net/mrr1ght/article/details/81011280 。本文有删减。

    tf.train.SessionRunHook()是一个类;用来定义Hooks;

    Hooks是什么,官方文档中关于training hooks的定义是:

    Hooks are tools that run in the process of training/evaluation of the model.

    Hooks是在模型训练/测试过程中的工具。Pytorch中也经常会有这个概念出现,其实也就跟keras里的callbacks一样,hook和callback都是在训练过程中执行特定的任务。

    例如判断是否需要停止训练的EarlyStopping;改变学习率的LearningRateScheduler,他们都有一个共性,就是在每个step开始/结束或者每个epoch开始/结束时需要执行某个操作。如每个epoch结束都保存一次checkpoint;每个epoch结束时都判断一次loss有没有下降,如果loss没有下降的轮数大于提取设定的阈值,就终止训练。当然以上的功能我们都可以自己完全重头实现。但是这些keras和tersorflow提供了更好的工具就是hook和callback,并且一些常用的功能都已经实现好了。说到底每个hook和callback都是按照固定格式定义了在每个step开始/结束要执行的操作,每个epoch开始/结束执行的操作。

    Hooks都是继承自父类tf.train.SessionRunHook(),首先看一下这个父类的定义源码;

    tf.train.SessionRunHook()定义

    tf.train.SessionRunHook()类定义在tensorflow/python/training/session_run_hook.py,类中每个函数的作用与什么时候调用都已加入函数注释中;

    class SessionRunHook(object):
      """Hook to extend calls to MonitoredSession.run()."""
     
      def begin(self):
        """再创建会话之前调用
        调用begin()时,default graph会被创建,
        可在此处向default graph增加新op,begin()调用后,default graph不能再被修改
        """
        pass
     
      def after_create_session(self, session, coord):  # pylint: disable=unused-argument
        """tf.Session被创建后调用
        调用后会指示所有的Hooks有一个新的会话被创建
        Args:
          session: A TensorFlow Session that has been created.
          coord: A Coordinator object which keeps track of all threads.
        """
        pass
     
      def before_run(self, run_context):  # pylint: disable=unused-argument
        """调用在每个sess.run()执行之前
        可以返回一个tf.train.SessRunArgs(op/tensor),在即将运行的会话中加入这些op/tensor;
        加入的op/tensor会和sess.run()中已定义的op/tensor合并,然后一起执行;
        Args:
          run_context: A `SessionRunContext` object.
        Returns:
          None or a `SessionRunArgs` object.
        """
        return None
      def after_run(self,
                    run_context,  # pylint: disable=unused-argument
                    run_values):  # pylint: disable=unused-argument
        """调用在每个sess.run()之后
        参数run_values是befor_run()中要求的op/tensor的返回值;
        可以调用run_context.qeruest_stop()用于停止迭代
        sess.run抛出任何异常after_run不会被调用
        Args:
          run_context: A `SessionRunContext` object.
          run_values: A SessionRunValues object.
        """
        pass
     
      def end(self, session):  # pylint: disable=unused-argument
        """在会话结束时调用
        end()常被用于Hook想要执行最后的操作,如保存最后一个checkpoint
        如果sess.run()抛出除了代表迭代结束的OutOfRange/StopIteration异常外,
        end()不会被调用
        Args:
          session: A TensorFlow Session that will be soon closed.
        """
        pass
    

    tf.train.SessionRunHook()类中定义的方法的参数run_context,run_values,run_args,包含sess.run()会话运行所需的一切信息,

    • run_context:类tf.train.SessRunContext的实例
    • run_values:类tf.train.SessRunValues的实例
    • run_args:类tf.train.SessRunArgs的实例.

    这三个类会在下面详细介绍

    tf.train.SessionRunHook()的使用

    (1)可以使用tf中已经预定义好的Hook,其都是tf.train.SessionRunHook()的子类;如

    • StopAtStepHook:设置用于停止迭代的max_step或num_step,两者只能设置其一
    • NanTensorHook:如果loss的值为Nan,则停止训练;
    • tensorflow中有许多预定义的Hook,想了解更多的同学可以去官方文档tf.train.下查看

    (2)也可用tf.train.SessionRunHook()定义自己的Hook,并重写类中的方法;然后把想要使用的Hook(预定义好的或者自己定义的)放到tf.train.MonitorTrainingSession()参数[Hook]列表中;

    关于tf.train.MonitorTrainingSession()参见tf.train.MonitoredTrainingSession()解析

    给一个定义自己Hook的栗子,来自cifar10

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""
     
      def begin(self):
        self._step = -1
        self._start_time = time.time()
     
      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.
     
      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time#duration持续的时间
          self._start_time = current_time
     
          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)
     
          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))
    

    SessRunContext/SessRunValues/SessRunArgs

    这三个类都服务于sess.run(),区别如下:

    • tf.train.SessRunContext和tf.train.SessRunArgs提供会话运行所需的信息,
    • tf.train.SessRunValues保存会话运行的结果

    (1) tf.train.SessRunArgs类
    提供给会话运行的参数,与sess.run()参数定义一样:
    fethes,feeds,option

    (2) tf.train.SessRunValues
    用于保存sess.run()的结果,其中resluts是sess.run()返回值中对应于SessRunArgs()的返回值,

    (3) tf.train.SessRunContext
    SessRunContext包含sess.run()所需的一切信息

    属性:

    • original_args:sess.run所需的参数,是一个tf.train.SessRunArgs实例
    • session:指定要运行的会话
    • stop_request:返回一个bool值,用于判断是否停止迭代;

    方法:

    equest_stop(): 设置_stop_request值为True

    cifar10 中的运用实例

    tf.train.SessionRunHook()和tf.train.MonitorTrainingSession()一般一起使用,下面是cifar10中的使用实例

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""
     
      def begin(self):
        self._step = -1
        self._start_time = time.time()
     
      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.
     
      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time#duration持续的时间
          self._start_time = current_time
     
          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)
     
          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))
     
                 #monitored 被监控的
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
    


    MARSGGBO原创





    2019-10-21 11:16:01



  • 相关阅读:
    NoSQL数据库 continue posting...
    CAP 理论
    Clojure Web 开发 (一)
    HttpClient 4.0.x Tips
    zZ Java中String和Byte[]之间的那些事
    使用nhibernate出现Could not find the dialect in the configuration
    eclipse导入项目出现Project has no default.properties file! Edit the project properties to set one.
    今天开通此博~
    美国白蛾入侵北京 GIS兵法破解危局
    HTML5 存取Json
  • 原文地址:https://www.cnblogs.com/marsggbo/p/11712616.html
Copyright © 2011-2022 走看看