zoukankan      html  css  js  c++  java
  • CNN基础四:监测并控制训练过程的法宝——Keras回调函数和TensorBoard

    训练模型时,很多事情一开始都无法预测。比如之前我们为了找出迭代多少轮才能得到最佳验证损失,可能会先迭代100次,迭代完成后画出运行结果,发现在中间就开始过拟合了,于是又重新开始训练。

    类似的情况很多,于是我们想要实时监测训练动态,并能根据训练情况及时对模型采取一定的措施。Keras中的回调函数和tf的TensorBoard就是为此而生。

    Keras回调函数

    回调函数(callbacks)是在调用fit时传入模型的一个对象,它在训练过程中的不同时间点都会被模型调用。它可以访问关于模型状态和性能的所有可用数据,还可以采取行动:中断训练、保存模型、加载一组不同的权重或者改变模型的状态。也就是说,之前在训练模型的过程中,我们不知道模型的实时状态,因此为了更好的监测和控制模型的训练过程,我们派出了一个特派员——回调函数,它可以根据情况记录、反馈或者采取措施。我们熟悉的训练进度条和fit返回的history都是回调函数,只不过它俩因为太常用,所以被单独拎出来。

    fit和fit_generator函数都提供了callbacks接口。常用的回调函数有:

    • ModelCheckpoint(在每轮过后保存当前模型);
    • EarlyStopping(如果监控参数得不到改善就中断训练);
    • LearningRateScheduler(在训练过程中动态调整学习率);
    • ReduceLROnPlateau(如果验证表现得不到改善,可以用它降低学习率,跳出局部最小值);
    • CSVLogger(将每个epoch的结果写入CSV文件)。
    • 其他回调函数,也可以根据需要自行编写。

    应用示例:

    from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
    
    #fit提供callbacks接口,接收一个回调函数列表,可将任意个回调函数传入模型中
    callback_lists = []
    
    callback_lists.append(EarlyStopping(monitor = 'acc', #监控模型的验证精度
                                        patience = 1)) #如果精度在多于一轮的时间(即两轮)内不再改善,就中断训练
    
    callback_lists.append(ModelCheckpoint(filepath = 'my_model.h5', #目标文件的保存路径
                                          monitor = 'val_loss',  #监控验证损失
                                          save_best_only = True)) #只保存最佳模型
    
    callback_lists.append(ReduceLROnPlateau(monitor = 'val_loss',  #监控模型的验证损失
                                          factor = 0.1,  #触发时将学习率乘以系数0.1
                                          patience = 10) #若验证损失在10轮内都没有改善,则触发该回调函数
    
    #由于回调函数要监控验证损失和验证精度,所以在调用fit时需要传入validation_data
    model.fit(x, y, epochs = 10, batch_size = 32,
             callbacks = callbacks_list,
             validation_data = (x_val, y_val))
    

    TensorBoard:实时可视化工具

    TensorBoard是内置于TensorFlow中基于浏览器的可视化工具,安装TensorFlow时会自动安装这个工具。简单来说,它就是把训练过程数据写入文件,然后用浏览器查看的工具。在Keras中,它也被包装成一个回调函数。

    示例如下:

    #引入Tensorboard
    from keras.callbacks import TensorBoard
    
    #定义回调函数列表,现在只放一个简单的TensorBoard
    log_path = './logs' #指定TensorBoard读取的文件路径,可以新建一个
    callback_lists = [TensorBoard(log_dir=log_path, histogram_freq=1)]
    
    #模型调用fit时,通过回调函数接口传入
    model.fit(...inputs and parameters..., callbacks=callback_lists)
    

    为了在训练的过程中可视化各项指标,需要自己在终端启动TensorBoard。

    打开终端的方式有两种:一种是系统自带的终端cmd;另一种是在Anaconda Prompt终端。选择用哪种终端打开,根据当时安装tensorflow时用的终端方式。我试了下cmd,总是出错,但在Anaconda Prompt终端就能正常启动。

    启动方式:在终端输入 tensorboard --logdir=C:Users...logs (自己文件的路径),就会返回一行信息,包含了一个http网址。这个地址一般是不会改变的,在浏览器中输入提示的http地址,即可查看模型的训练过程和相关状态,如下图所示。

    Reference

    书籍:Python深度学习

  • 相关阅读:
    打开服务器的文档
    笔记
    centos6.5 编译openssl-1.1.1k
    搭建自己的低代码平台
    防火墙ACL配置自动化
    防火墙ACL配置自动化
    【树莓派】读取新大陆(newland)USB条码扫描器数据
    解决eclipse或sts闪退的办法(转)
    浅谈数据库迁移类项目功能测试的基本思路
    ATM取款机优化需求的用例设计
  • 原文地址:https://www.cnblogs.com/inchbyinch/p/11986629.html
Copyright © 2011-2022 走看看