zoukankan      html  css  js  c++  java
  • 训练高级会话函数

    主训练逻辑

    我们将在cifar_train.py文件实现主要训练逻辑。在这里我们将使用一个新的会话函数,叫tf.train.MonitoredTrainingSession

    优点: 1、它自动的建立events文件、checkpoint文件,以记录重要的信息。 2、可以定义钩子函数,可以自定义每批次的训练信息,训练的限制等等

    注意:在这个里面我们需要添加一个全局步数,这个步数是每批次训练的时候进行+1计数,内部使用。

    代码如下:

    import tensorflow as tf
    import cifar_model
    import time
    from datetime import datetime
    
    
    
    def train():
        # 在图中进行训练
        with tf.Graph().as_default():
            # 定义全局步数,必须得使用这个,否则会出现StopCounterHook错误
            global_step = tf.contrib.framework.get_or_create_global_step()
    
            # 获取数据
            image, label, label_1 = cifar_model.input()
    
            # 通过模型进行类别预测
            y_logit = cifar_model.inference(image)
    
            # 计算损失
            loss = cifar_model.total_loss(label, y_logit)
    
            # 进行优化器减少损失
            train_op, accuracy = cifar_model.train(loss, label, y_logit, global_step)
    
            # 通过钩子定义模型输出
            class _LoggerHook(tf.train.SessionRunHook):
                """Logs loss and runtime."""
                def begin(self):
                    self._step = -1
                    self._start_time = time.time()
    
                def before_run(self, run_context):
                    self._step += 1
                    return tf.train.SessionRunArgs(loss, float(accuracy.eval()))  # Asks for loss value.
    
                def after_run(self, run_context, run_values):
                    if self._step % 10 == 0:
                        current_time = time.time()
                        duration = current_time - self._start_time
                        self._start_time = current_time
                        loss_value = run_values.results
                        examples_per_sec = 10 * 10 / duration
                        sec_per_batch = float(duration / 10)
    
                        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                      'sec/batch)')
                        print(format_str % (datetime.now(), self._step, loss_value,
                                            examples_per_sec, sec_per_batch))
    
            with tf.train.MonitoredTrainingSession(
                    checkpoint_dir="./cifartrain/train",
                    hooks=[tf.train.StopAtStepHook(last_step=500),# 定义执行的训练轮数也就是max_step,超过了就会报错
                           tf.train.NanTensorHook(loss),
                           _LoggerHook()],
                    config=tf.ConfigProto(
                        log_device_placement=False)) as mon_sess:
                while not mon_sess.should_stop():
                    mon_sess.run(train_op)
    
    
    def main(argv):
        train()
    
    
    if __name__ == "__main__":
        tf.app.run()

     

  • 相关阅读:
    NPOI单元格公式不刷新
    DIV+CSS HACK
    简答好用的邮件服务器hMailServer(转)
    C# 后台POST和GET 获取数据
    Quartz.Net1.0.2.3 配置记录
    ASP.NET自定义控件VS2012中添加失败(下列控件已成功添加到工具箱中,但未在活动设计器中启用)
    NPOI 1.2.5复制行(包括格式)
    Javascript中Null和Undefined的区别[转]
    测试流程(立项会)
    测试计划
  • 原文地址:https://www.cnblogs.com/alexzhang92/p/10070155.html
Copyright © 2011-2022 走看看