zoukankan      html  css  js  c++  java
  • Callback API

    Callback API

    用于跟踪epoch期间各种状态的回调函数。主要有6个类:

    1. mxnet.callback.module_checkpoint(modprefixperiod=1save_optimizer_states=False)

    [source]

    参数:

    • mod:BaseModule的子类。需要做checkpoint的module
    • prefix:字符串,该checkpoint文件的前缀
    • period:在做checkpoint之前需要等多少个epoch,默认为1
    • save_optimizer_states:布尔型,表明是否保存优化器状态用于继续训练

    返回:

    • callback:callback函数,可被作为iter_end_callback参数传递到fit函数里。

    2. mxnet.callback.do_checkpoint(prefixperiod=1)

    这个callback函数用于每隔几个epoch来保存以下模型checkpoint,每个checkpoint由几个binary files组成:一个模型描述文件和一个参数(权重和偏置)文件。模型描述文件名字为prefix-symbol.json,参数文件名字为prefix-epoch_number.params

    参数:

    • prefix:同上
    • period:整型,可选。几个epoch来保存一次。默认为1

    返回:

    • callback:一个callback函数,可被作为epoch_end_callback参数传递到fit函数里。
    >>> module.fit(iterator, num_epoch=n_epoch,
    ... epoch_end_callback  = mx.callback.do_checkpoint("mymodel", 1))
    Start training with [cpu(0)] Epoch[0] Resetting Data Iterator Epoch[0] Time cost
    =0.100 Saved checkpoint to "mymodel-0001.params" Epoch[1] Resetting Data Iterator Epoch[1] Time cost=0.060 Saved checkpoint to "mymodel-0002.params"

    3. mxnet.callback.log_train_metric(periodauto_reset=False)

    callback函数用于每隔几个周期记录训练打印结果

    参数:

    • period:整型,打印多少个batch的训练结果
    • auto_reset:布尔型,每次打印后重置评估函数

    返回:

    • callback:callback函数,可被作为iter_epoch_callback参数传递到fit函数里。

    4. class mxnet.callback.Speedometer(batch_sizefrequent=50auto_reset=True)

    周期性的打印训练速度和评价指标

    参数:

    • batch_size:整型
    • frequent:打印频率,默认每50个批量打印一次
    • auto_set:同上

    例子:

    >>> # Print training speed and evaluation metrics every ten batches. Batch size is one.
    >>> module.fit(iterator, num_epoch=n_epoch,
    ... batch_end_callback=mx.callback.Speedometer(1, 10))
    Epoch[0] Batch [
    10] Speed: 1910.41 samples/sec Train-accuracy=0.200000 Epoch[0] Batch [20] Speed: 1764.83 samples/sec Train-accuracy=0.400000 Epoch[0] Batch [30] Speed: 1740.59 samples/sec Train-accuracy=0.500000

    5. class mxnet.callback.ProgressBar(totallength=80)

    [source]

    呈现一个进度条,表明每个epoch内批量的进度。

    参数:

    • total:每个epoch中所有批量的数目
    • length:进度条的最大长度

    例子:

    >>> progress_bar = mx.callback.ProgressBar(total=2)
    >>> mod.fit(data, num_epoch=5, batch_end_callback=progress_bar)
    [========--------] 50.0%
    [================] 100.0%

    6. class mxnet.callback.LogValidationMetricsCallback

    打印出一个epoch之后的评估结果

     

    整体的一个例子:train_mnist.py:用到了第2个和第4个类:

        model.fit(train,
                  begin_epoch=args.load_epoch if args.load_epoch else 0,
                  num_epoch=args.num_epochs,
                  eval_data=val,
                  eval_metric=eval_metrics,
                  kvstore=kv,
                  optimizer=args.optimizer,
                  optimizer_params=optimizer_params,
                  initializer=initializer,
                  arg_params=arg_params,
                  aux_params=aux_params,
                  batch_end_callback=[mx.callback.Speedometer(args.batch_size, args.disp_batches)],        # 每过多少个batch打印一下
                  epoch_end_callback=mx.callback.do_checkpoint(args.model_prefix , period=args.save_period),      # 每过多少period保存模型
                  allow_missing=True,
                  monitor=monitor)
  • 相关阅读:
    绑定方式开始服务&调用服务的方法
    采用服务窃听电话示例
    后台服务运行示例
    Android短信监听器——示例
    利用广播实现ip拨号——示例
    Android图片的合成示例
    IIS 7.5 发布Web 网站步骤
    C# 中 多线程同步退出方案 CancellationTokenSource
    UML 类图常用表示方法.
    Socket Receive 避免 Blocking
  • 原文地址:https://www.cnblogs.com/king-lps/p/13060915.html
Copyright © 2011-2022 走看看