zoukankan      html  css  js  c++  java
  • keras_10_回调函数 Callbacks

    1. 回调函数的使用

    • 回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks 关键字参数)到 SequentialModel 类型的 .fit() 方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。

    2. keras支持的回调函数

    1. Callback

      • 用来组建新的回调函数的抽象基类。keras.callbacks.Callback()
      • 被回调函数作为参数的 logs 字典,它会含有于当前批量或训练轮相关数据的键。
      • 目前,Sequential 模型类的 .fit() 方法会在传入到回调函数的 logs 里面包含以下的数据:
        • on_epoch_end: 包括 accloss 的日志, 也可以选择性的包括 val_loss(如果在 fit 中启用验证),和 val_acc(如果启用验证和监测精确值)。
        • on_batch_begin: 包括 size 的日志,在当前批量内的样本数量。
        • on_batch_end: 包括 loss 的日志,也可以选择性的包括 acc(如果启用监测精确值)。
    2. BaseLogger

      • 会积累训练轮平均评估的回调函数。这个回调函数被自动应用到每一个 Keras 模型上面。
    3. TerminateOnNaN

      • 当遇到 NaN 损失会停止训练的回调函数。
    4. ProgbarLogger

      • 会把评估以标准输出打印的回调函数。(即控制台输出训练的进度条及中间信息)
    5. History

      • 把所有事件都记录到 History 对象的回调函数。这个回调函数被自动启用到每一个 Keras 模型。History 对象会被模型的 fit 方法返回。
    6. ModelCheckpoint

      • 在每个训练期之后保存模型。filepath 可以包括命名格式选项,可以由 epoch 的值和 logs 的键(由 on_epoch_end 参数传递)来填充。例如:如果 filepathweights.{epoch:02d}-{val_loss:.2f}.hdf5, 那么模型被保存的的文件名就会有训练轮数和验证损失。
    7. EarlyStopping

      • 当被监测的某指标不再提升,则停止训练。
    8. RemoteMonitor

      • 将事件数据流到服务器的回调函数。需要 requests 库。 事件被默认发送到 root +'/publish/epoch/end/'。 采用 HTTP POST ,其中的 data 参数是以 JSON 编码的事件数据字典。
      • 感觉这是给远程监控服务器训练过程的用户准备的功能
    9. LearningRateScheduler

      • 学习速率定时调度器
    10. TensorBoard

      • Tensorboard 基本可视化。TensorBoard 是由 Tensorflow 提供的一个可视化工具。这个回调函数为 Tensorboard 编写一个日志, 这样你可以可视化测试和训练的标准评估的动态图像, 也可以可视化模型中不同层的激活值直方图。如果你已经使用 pip 安装了 Tensorflow,你应该可以从命令行启动 Tensorflow:tensorboard --logdir=/full_path_to_your_logs
    11. ReduceLROnPlateau

      • 当标准评估已经停止(训练到了plateau)时,降低学习速率。当学习停止时,模型总是会受益于降低 2-10 倍的学习速率
      • 这个回调函数监测一个数据并且当这个数据在一定「有耐心」的训练轮之后还没有进步, 那么学习速率就会被降低
    12. CSVLogger

      • 把训练轮结果数据流到 csv 文件的回调函数。支持所有可以被作为字符串表示的值,包括 1D 可迭代数据,例如,np.ndarray。例如,

        csv_logger = CSVLogger('training.log')
        model.fit(X_train, Y_train, callbacks=[csv_logger])
        
    13. LambdaCallback

      • 在训练进行中创建简单,自定义的回调函数的回调函数。这个回调函数和匿名函数在合适的时间被创建。 需要注意的是回调函数要求位置型参数,如下:

        • on_epoch_beginon_epoch_end 要求两个位置型的参数: epoch, logs
        • on_batch_beginon_batch_end 要求两个位置型的参数: batch, logs
        • on_train_beginon_train_end 要求一个位置型的参数: logs
      • 参数

        • on_epoch_begin: 在每轮开始时被调用。
        • on_epoch_end: 在每轮结束时被调用。
        • on_batch_begin: 在每批开始时被调用。
        • on_batch_end: 在每批结束时被调用。
        • on_train_begin: 在模型训练开始时被调用。
        • on_train_end: 在模型训练结束时被调用。
      • 例如,

        # 在每一个批开始时,打印出批数。
        batch_print_callback = LambdaCallback(
            on_batch_begin=lambda batch,logs: print(batch))
        
        # 把训练轮损失数据 流到 JSON 格式的文件。
        # 文件的内容不是完美的 JSON 格式,但是时每一行都是 JSON 对象。
        import json
        json_log = open('loss_log.json', mode='wt', buffering=1)
        json_logging_callback = LambdaCallback(
            on_epoch_end=lambda epoch, logs: json_log.write(
                json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '
        '),
            on_train_end=lambda logs: json_log.close()
        )
        
        # 在完成模型训练之后,结束一些进程。
        processes = ...
        cleanup_callback = LambdaCallback(
            on_train_end=lambda logs: [
                p.terminate() for p in processes if p.is_alive()])
        
        model.fit(...,
                  callbacks=[batch_print_callback,
                             json_logging_callback,
                             cleanup_callback])
        

    3. 自定义回调函数

    1. 你可以通过扩展 keras.callbacks.Callback 基类来创建一个自定义的回调函数。 通过类的属性 self.model,回调函数可以获得它所联系的模型。

    2. 例1:在训练时,保存一个列表的批量损失值:

      class LossHistory(keras.callbacks.Callback):
          def on_train_begin(self, logs={}):
              self.losses = []
      
          def on_batch_end(self, batch, logs={}):
              self.losses.append(logs.get('loss'))
      
    3. 例2:记录损失历史

      class LossHistory(keras.callbacks.Callback):
          def on_train_begin(self, logs={}):
              self.losses = []
      
          def on_batch_end(self, batch, logs={}):
              self.losses.append(logs.get('loss'))
      
      model = Sequential()
      model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
      model.add(Activation('softmax'))
      model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
      
      history = LossHistory()
      model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])
      
      print(history.losses)
      # 输出
      '''
      [0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]
      '''
      
    4. 例3:模型检查点

      from keras.callbacks import ModelCheckpoint
      
      model = Sequential()
      model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
      model.add(Activation('softmax'))
      model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
      
      '''
      如果验证损失下降, 那么在每个训练轮(epoch)之后保存模型。
      '''
      checkpointer = ModelCheckpoint(filepath='/tmp/weights.hdf5', verbose=1, save_best_only=True)
      model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, validation_data=(X_test, Y_test), callbacks=[checkpointer])
      
  • 相关阅读:
    Chamfer Distance--倒角距离
    javax.net.ssl.SSLHandshakeException: Received fatal alert: handshake_failure
    mysql单个索引和联合索引的区别
    鸽一下
    笔记:关于 INT1 INT0 中断说明记录 (2020-07-16)[85.22%]
    使用 Git 管理 KiCad EDA 项目文件 [2020-06-28][26.77%]
    从单片机基础到程序框架 2019版(2020-07-04)[12.66%]
    KiCad Pcbnew 中现代工具箱 (2020-06-24)[98.33%]
    【营养研究一】鸡蛋和牛奶的营养对比 (2020-06-23)[95.89%]
    git 忽略上传指定文件 命令
  • 原文地址:https://www.cnblogs.com/LS1314/p/10380649.html
Copyright © 2011-2022 走看看