zoukankan      html  css  js  c++  java
  • TensorFlow 便捷的实现机器学习 三

    TensorFlow 便捷的实现机器学习 三

     

     

    Overview


    Iris花瓣分类中,运行结果后,我们最后只是知道一个最终的结果:

    Accuracy: 0.933333
    Predictions: [1 2]
    

    我们并不能知道tensorflow执行的过程中发生了写什么。
    有一种方式是通过在训练过程中通过多次fit来一步一步的得到结果,但是这种方式会大大的影响执行效率。我们可以使用==tf.contrib.learn提供的Monitor API工具来实现监控。下面主要学习的时候启动Logging一级TensorBoard来对实现过程做一个监控。

    Enabling Logging with TensorFlow


    TensorFlow提供了五个等级的日记记录,分别是:

    1. DEBUG
    2. INFO
    3. WARN
    4. ERROR
    5. FATAL

    默认的情况下,TensorFlow主要设置为WARN等级。我们可以自行调整结果

    tf.logging.set_verbosity(tf.logging.INFO)

    这样子在运行程序的话,就会看到如下信息:

    INFO:tensorflow:Training steps [0,200)
    INFO:tensorflow:global_step/sec: 0
    INFO:tensorflow:Step 1: loss_1:0 = 1.48073
    INFO:tensorflow:training step 100, loss = 0.19847 (0.001 sec/batch).
    INFO:tensorflow:Step 101: loss_1:0 = 0.192693
    INFO:tensorflow:Step 200: loss_1:0 = 0.0958682
    INFO:tensorflow:training step 200, loss = 0.09587 (0.003 sec/batch).
    

    Configuring a ValidationMonitor for Streaming Evaluation


    记录训练损失有助于了解你的模型是否融合,但如果你想进一步了解培训期间发生了什么,你有该怎么办?tf.contrib.learn提供了几个高级的Monitor,您可以附加到您的合适的操作,以进一步跟踪指标和/或调试较低级别的TensorFlow操作在模型训练,主要包括:

    MonitorDescription
    CaptureVariable Saves a specified variable's values into a collection at every n steps of training
    PrintTensor Logs a specified tensor's values at every n steps of training
    SummarySaver Saves Summary protocol buffers for a given tensor using a SummaryWriter at every n steps of training
    ValidationMonitor Logs a specified set of evaluation metrics at every n steps of training, and, if desired, implements early stopping under certain conditions

    Evaluating Every N Steps

    在设置校验ValidationMonitor的时候,你也许想看看这个模型的泛化程度,这个时候你就可以通过设置(test_set.data and test_set.target),以及显示的频率来查看:

    validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
        test_set.data,
        test_set.target,
        every_n_steps=50)
    

    然后将这个行代码放在实例化的classifier之前。ValidationMonitor依赖保存的检查点来执行评估操作,因此您需要修改分类器的实例化以添加包含save_checkpoints_secsRunConfig,它指定在训练期间在检查点保存之间应该经过多少秒。
    classifier可是如下设置:

    classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                                hidden_units=[10, 20, 10],
                                                n_classes=3,
                                                model_dir="/tmp/iris_model",
                                                config=tf.contrib.learn.RunConfig(
                                                    save_checkpoints_secs=1))
    

    然后,再讲设置好的validation_monitor放进去

    classifier.fit(x=training_set.data,
                   y=training_set.target,
                   steps=2000,
                   monitors=[validation_monitor])
    

    到此,就可以运行代码,然后就能看到:

    INFO:tensorflow:Validation (step 50): loss = 1.71139, global_step = 0, accuracy = 0.266667
    ...
    INFO:tensorflow:Validation (step 300): loss = 0.0714158, global_step = 268, accuracy = 0.966667
    ...
    INFO:tensorflow:Validation (step 1750): loss = 0.0574449, global_step = 1729, accuracy = 0.966667
    
    

    Customizing the Evaluation Metrics

    默认情况下,如果未指定评估指标,ValidationMonitor将同时记录损失和精确度,但您可以自定义每隔50个步骤运行的指标列表。tf.contrib.metrics模块为您可以与ValidationMonitor一起使用的分类模型提供各种其他度量功能,包括streaming_precision和streaming_recall。要指定要在每个评估传递中运行的确切指标,请向ValidationMonitor构造函数中添加一个指标参数。指标采用键/值对的dict,其中每个键是您要为该指标记录的名称,相应的值是计算它的函数。

    按照如下方式修改ValidationMonitor构造函数,以添加精度和回调的记录,以及精度(损失总是记录,不需要明确指定):

    validation_metrics = {"accuracy": tf.contrib.metrics.streaming_accuracy,
                          "precision": tf.contrib.metrics.streaming_precision,
                          "recall": tf.contrib.metrics.streaming_recall}
    validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
        test_set.data,
        test_set.target,
        every_n_steps=50,
        metrics=validation_metrics)
    

    Early Stopping with ValidationMonitor

    注意,在上述对数输出中,通过步骤150,模型已经实现了1.0的精确度和召回率。这提出了一个问题,即模型训练是否可以从早期停止中受益。除了记录eval指标,ValidationMonitor使得在满足指定条件时容易实现提前停止,通过如下参数:

    ParamDescription
    early_stopping_metric Metric that triggers early stopping (e.g., loss or accuracy) under conditions specified in early_stopping_rounds and early_stopping_metric_minimize. Default is "loss".
    early_stopping_metric_minimize True if desired model behavior is to minimize the value of early_stopping_metric; False if desired model behavior is to maximize the value of early_stopping_metric. Default is True.
    early_stopping_rounds Sets a number of steps during which if the early_stopping_metric does not decrease (if early_stopping_metric_minimize is True) or increase (if early_stopping_metric_minimize is False), training will be stopped. Default is None, which means early stopping will never occur.

    我们可以做如下设置:

    validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
        test_set.data,
        test_set.target,
        every_n_steps=50,
        metrics=validation_metrics,
        early_stopping_metric="loss",
        early_stopping_metric_minimize=True,
        early_stopping_rounds=200)
    

    这样就会提前停止,而不需要到2000步,结果如下:

    ...
    INFO:tensorflow:Validation (step 1450): recall = 1.0, accuracy = 0.966667, global_step = 1431, precision = 1.0, loss = 0.0550445
    INFO:tensorflow:Stopping. Best step: 1150 with loss = 0.0506100878119.
    

    实际上,这里的训练在步骤1450停止,指示对于过去200个步骤,损失没有减少,并且总体来说,步骤1150针对测试数据集产生最小损失值。这表明通过减少步数来额外校准超参数可以进一步改善模型。

    Visualizing Log Data with TensorBoard


    通过阅读ValidationMonitor生成的日志,可以在训练期间提供大量有关模型性能的原始数据,但也可以查看此数据的可视化,以便进一步了解趋势,例如,精确度如何更改步数。您可以使用TensorBoard(与TensorFlow一起打包的单独程序)通过将logdir命令行参数设置为保存模型训练数据的目录(此处为/ tmp / iris_model)来绘制这样的图。在命令行上运行以下命令:

    $ tensorboard --logdir=/tmp/iris_model/
    Starting TensorBoard 22 on port 6006
    (You can navigate to http://0.0.0.0:6006)
    

    然后在浏览器中加载提供的URL(此处为http://0.0.0.0:6006)。就可以可视化查看结果了。

    reference

    [1] https://www.tensorflow.org/tutorials/monitors/

  • 相关阅读:
    继承
    对象与类
    反射
    I/O流
    字符串
    Map的entrySet()方法
    接口与内部类
    Git Usage Summary
    HTML(5)
    毕业设计:下载
  • 原文地址:https://www.cnblogs.com/flyu6/p/7691067.html
Copyright © 2011-2022 走看看