zoukankan      html  css  js  c++  java
  • mxnet module.fit源码解析

     module.fit(BaseModule.fit)源码分析

    首先来到 module 模块中,即 https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/module ,进入 base_module.py,便可以看到 BaseModule 类中 fit() 的定义。下面同时列出了 fit() 内部调用的 forward_backward() 和 score()。

    class BaseModule(object):
        """High-level training/evaluation API of a module (excerpt).

        Only ``forward_backward``, ``score`` and ``fit`` are shown here. The methods they
        call (``bind``, ``forward``, ``backward``, ``update``, ``update_metric``,
        ``prepare``, ``init_params``, ``init_optimizer``, ``get_params``, ``set_params``,
        ``install_monitor``) and the ``logger``/``symbol``/``binded``/
        ``params_initialized`` attributes are presumably defined elsewhere in the class
        or by subclasses -- not visible in this excerpt.
        """
        ################################################################################
        # High Level API
        ################################################################################
        def forward_backward(self, data_batch):
            """A convenient function that calls both ``forward`` and ``backward``.

            Runs ``forward(data_batch, is_train=True)`` followed by ``backward()``.
            Parameters are not updated here; ``fit`` calls ``update()`` separately.
            """
            self.forward(data_batch, is_train=True)
            self.backward()

        # Evaluation on a validation/evaluation set (fit() calls this after each epoch).
        def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
                  score_end_callback=None,
                  reset=True, epoch=0, sparse_row_id_fn=None):
            """Runs prediction on ``eval_data`` and evaluates the performance according to
            the given ``eval_metric``.

            Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
            to see an end-to-end use-case.

            Parameters
            ----------
            eval_data : DataIter
                Evaluation data to run prediction on.
            eval_metric : EvalMetric or list of EvalMetrics
                Evaluation metric to use. Anything that is not already an ``EvalMetric``
                (e.g. a metric name or a list of names) is passed to ``metric.create``.
            num_batch : int
                Number of batches to run. Defaults to ``None``, indicating run until the `DataIter`
                finishes.
            batch_end_callback : function
                Could also be a list of functions. Called after every batch with a
                ``BatchEndParam`` carrying ``epoch``, ``nbatch``, ``eval_metric`` and this
                method's ``locals()``.
            score_end_callback : function
                Could also be a list of functions. Called once after the last batch with a
                ``BatchEndParam`` whose ``nbatch`` is the number of batches actually run.
            reset : bool
                Defaults to ``True``. Indicates whether we should reset `eval_data` before starting
                evaluating.
            epoch : int
                Defaults to 0. For compatibility, this will be passed to callbacks (if any).
                During training, this will correspond to the training epoch number.
            sparse_row_id_fn : A callback function
                The function  takes `data_batch` as an input and returns a dict of
                str -> NDArray. The resulting dict is used for pulling row_sparse
                parameters from the kvstore, where the str key is the name of the param,
                and the value is the row id of the param to pull.

            Returns
            -------
            list of (name, value) pairs
                ``eval_metric.get_name_value()`` after all batches were processed.

            Examples
            --------
            >>> # An example of using score for prediction.
            >>> # Evaluate accuracy on val_dataiter
            >>> metric = mx.metric.Accuracy()
            >>> mod.score(val_dataiter, metric)
            >>> mod.score(val_dataiter, ['mse', 'acc'])
            """
            assert self.binded and self.params_initialized

            # Rewind the evaluation iterator unless the caller opted out.
            if reset:
                eval_data.reset()

            # Turn a metric spec (str / list) into an EvalMetric.
            if not isinstance(eval_metric, metric.EvalMetric):
                eval_metric = metric.create(eval_metric)

            eval_metric.reset()
            actual_num_batch = 0

            # Iterate evaluation batches, stopping after num_batch batches if given.
            for nbatch, eval_batch in enumerate(eval_data):
                if num_batch is not None and nbatch == num_batch:
                    break
                # Let the module prepare for this batch (sparse_row_id_fn is forwarded).
                self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
                # Inference-mode forward pass.
                self.forward(eval_batch, is_train=False)
                # Update metrics; a list batch is passed as pre-sliced labels.
                if isinstance(eval_batch, list):
                    self.update_metric(eval_metric, [eb.label for eb in eval_batch], pre_sliced=True)
                else:
                    self.update_metric(eval_metric, eval_batch.label)

                # Per-batch callbacks. NOTE: they also receive this frame's locals(), so
                # local variable names in this method are visible to callbacks.
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                actual_num_batch += 1

            # End-of-scoring callbacks; nbatch is the number of batches actually processed.
            if score_end_callback:
                params = BatchEndParam(epoch=epoch,
                                       nbatch=actual_num_batch,
                                       eval_metric=eval_metric,
                                       locals=locals())
                for callback in _as_list(score_end_callback):
                    callback(params)

            # (name, value) pairs of all metrics.
            return eval_metric.get_name_value()

        def fit(self, train_data, eval_data=None, eval_metric='acc',
                epoch_end_callback=None, batch_end_callback=None, kvstore='local',
                optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
                eval_end_callback=None,
                eval_batch_end_callback=None, initializer=Uniform(0.01),
                arg_params=None, aux_params=None, allow_missing=False,
                force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
                validation_metric=None, monitor=None, sparse_row_id_fn=None):
            """Trains the module parameters.

            Binds the module to ``train_data``'s shapes, then initializes parameters and
            the optimizer. It then runs epochs ``begin_epoch`` .. ``num_epoch - 1``. Each
            epoch does the following:

            - iterates over ``train_data`` (prepare, forward, backward, update, metric
              update and ``batch_end_callback`` per batch);
            - logs the training metrics and the wall-clock time;
            - syncs parameters and fires ``epoch_end_callback``;
            - scores ``eval_data`` if given, then resets ``train_data``.

            Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
            to see an end-to-end use-case.

            Parameters
            ----------
            train_data : DataIter
                Training data iterator.
            eval_data : DataIter
                If not ``None``, used as the validation set and scored after every epoch.
            eval_metric : str or EvalMetric
                Defaults to ``'acc'`` (accuracy). Metric reported during training. Other
                predefined metrics include 'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse',
                'top_k_accuracy'.
            epoch_end_callback : function or list of functions
                Called at the end of every epoch as
                ``callback(epoch, symbol, arg_params, aux_params)``.
            batch_end_callback : function or list of function
                Called after every training batch with a ``BatchEndParam``.
            kvstore : str or KVStore
                Defaults to 'local'. Forwarded to ``init_optimizer``. The original
                article lists 'local' (update on CPU), 'device' (update on GPU) and
                distributed types such as 'dist_device_sync'.
            optimizer : str or Optimizer
                Defaults to 'sgd'.
            optimizer_params : dict
                Defaults to ``(('learning_rate', 0.01),)``. Optimizer arguments.
            eval_end_callback : function or list of function
                Forwarded to ``score`` as ``score_end_callback``. Called once the whole
                validation set has been scored.
            eval_batch_end_callback : function or list of function
                Forwarded to ``score`` as ``batch_end_callback``. Called after every
                validation batch.
            initializer : Initializer
                Used to initialize module parameters that are not yet initialized.
            arg_params : dict
                Defaults to ``None``. If not ``None``, these values are used instead of
                ``initializer``.
            aux_params : dict
                Defaults to ``None``. If not ``None``, these values are used instead of
                ``initializer``.
            allow_missing : bool
                Defaults to ``False``. Whether missing parameters are allowed when
                ``arg_params``/``aux_params`` are given. If ``True``, missing parameters
                are initialized with ``initializer``.
            force_rebind : bool
                Defaults to ``False``. Whether to force rebinding executors that are
                already bound.
            force_init : bool
                Defaults to ``False``. Whether to force re-initialization of parameters
                that are already initialized.
            begin_epoch : int
                Defaults to 0. Index of the first epoch. When resuming from a training
                phase saved at Epoch[n], this is usually n+1.
            num_epoch : int
                Required. Training runs epochs up to, but not including, this index.
            validation_metric : str or EvalMetric
                Metric used for ``eval_data``. If ``None``, falls back to the
                ``eval_metric`` argument as it was passed in.
            monitor : Monitor
                If not ``None``, it is installed via ``install_monitor``.
                ``monitor.tic()`` and ``monitor.toc_print()`` are called around every
                training batch.
            sparse_row_id_fn : A callback function
                The function  takes `data_batch` as an input and returns a dict of
                str -> NDArray. The resulting dict is used for pulling row_sparse
                parameters from the kvstore, where the str key is the name of the param,
                and the value is the row id of the param to pull.

            Examples
            --------
            >>> # An example of using fit for training.
            >>> # Assume training dataIter and validation dataIter are ready
            >>> # Assume loading a previously checkpointed model
            >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
            >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
            ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
            ...     arg_params=arg_params, aux_params=aux_params,
            ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
            """
            assert num_epoch is not None, 'please specify number of epochs'

            # Bind executors to the data/label shapes advertised by the training iterator.
            self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                      for_training=True, force_rebind=force_rebind)
            if monitor is not None:
                self.install_monitor(monitor)
            # Initialize parameters (initializer vs. arg_params/aux_params, see docstring).
            self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                             allow_missing=allow_missing, force_init=force_init)
            # Create the optimizer with the requested kvstore.
            self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                                optimizer_params=optimizer_params)

            # Validation falls back to the caller's eval_metric spec. This is captured
            # before eval_metric is converted below, so for a str/list spec score() builds
            # its own EvalMetric. An EvalMetric instance is shared; score() resets it, but
            # only after the training values for the epoch have been logged.
            if validation_metric is None:
                validation_metric = eval_metric
            # Turn a str (or list) metric spec into an EvalMetric.
            if not isinstance(eval_metric, metric.EvalMetric):
                eval_metric = metric.create(eval_metric)

            ################################################################################
            # training loop
            ################################################################################
            for epoch in range(begin_epoch, num_epoch):
                tic = time.time()
                # Metrics are reset at the start of every epoch.
                eval_metric.reset()
                nbatch = 0
                data_iter = iter(train_data)
                end_of_batch = False
                # NOTE: if train_data yields no batch at all, this StopIteration is not
                # caught and propagates out of fit().
                next_data_batch = next(data_iter)
                # Prepare the first batch too. Every later batch is prepared right after it
                # is pre-fetched inside the loop, and score() prepares each batch before its
                # forward pass. Without this call, the first batch of every epoch would miss
                # e.g. the row_sparse pulls driven by sparse_row_id_fn.
                self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
                while not end_of_batch:
                    data_batch = next_data_batch
                    if monitor is not None:
                        monitor.tic()
                    # Forward + backward to compute gradients, then apply the optimizer step.
                    self.forward_backward(data_batch)
                    self.update()

                    # Accumulate training metrics; a list batch is passed as pre-sliced labels.
                    if isinstance(data_batch, list):
                        self.update_metric(eval_metric,
                                           [db.label for db in data_batch],
                                           pre_sliced=True)
                    else:
                        self.update_metric(eval_metric, data_batch.label)

                    try:
                        # pre fetch next batch
                        next_data_batch = next(data_iter)
                        self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
                    except StopIteration:
                        end_of_batch = True

                    if monitor is not None:
                        monitor.toc_print()

                    # After the last batch, snapshot the global metric values for the
                    # end-of-epoch summary below.
                    if end_of_batch:
                        eval_name_vals = eval_metric.get_global_name_value()

                    # Per-batch callbacks. They also receive this frame's locals().
                    if batch_end_callback is not None:
                        batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                         eval_metric=eval_metric,
                                                         locals=locals())
                        for callback in _as_list(batch_end_callback):
                            callback(batch_end_params)
                    nbatch += 1

                # one epoch of training is finished
                # Log every training metric as "Train-<name>=<value>", then the epoch time.
                for name, val in eval_name_vals:
                    self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
                toc = time.time()
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

                # sync aux params across devices
                arg_params, aux_params = self.get_params()
                self.set_params(arg_params, aux_params)

                if epoch_end_callback is not None:
                    for callback in _as_list(epoch_end_callback):
                        callback(epoch, self.symbol, arg_params, aux_params)

                #----------------------------------------
                # evaluation on validation set (only when a validation iterator was given)
                if eval_data is not None:
                    res = self.score(eval_data, validation_metric,
                                     score_end_callback=eval_end_callback,
                                     batch_end_callback=eval_batch_end_callback, epoch=epoch)
                    #TODO: pull this into default
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

                # end of 1 epoch, reset the data-iter for another epoch
                train_data.reset()
    

     

    训练 log 源码分析:训练过程中按 batch 输出的 Speed 日志,由作为 batch_end_callback 传入 fit() 的 mx.callback.Speedometer 产生。

    _cb = mx.callback.Speedometer(batch_size=batch_size, frequent=frequent)
    def _batch_callback(param):
        """Batch-end hook that hands ``param`` to the shared Speedometer ``_cb``.

        Produces training log lines such as:
        INFO:root:Epoch[26] Batch [0-20] Speed: 257.26 samples/sec acc=0.968571 lossvalue=0.331392
        """
        # ``_cb`` is resolved when the hook runs, exactly like a direct ``_cb(param)``.
        speedometer = _cb
        speedometer(param)
    class Speedometer(object):
        """Periodically log training throughput and, optionally, metric values.

        Meant to be used as a ``batch_end_callback``. Every ``frequent`` batches it logs
        how many samples per second were processed since the previous log line. If
        ``param.eval_metric`` is set, it also logs that metric's ``name=value`` pairs.
        The first call, and the first call after ``nbatch`` goes backwards (a new
        epoch), only starts the timer.

        Parameters
        ----------
        batch_size: int
            Number of samples per batch; converts batches/sec into samples/sec.
        frequent: int
            Log once every ``frequent`` batches. Default is 50.
        auto_reset : bool
            If True (default), ``eval_metric.reset_local()`` is called after each log
            line and the line is labelled ``Batch [N-frequent-N]``. Otherwise local
            values are not reset here and the line is labelled ``Batch [0-N]``.

        Example
        -------
        >>> # Print training speed and evaluation metrics every ten batches. Batch size is one.
        >>> module.fit(iterator, num_epoch=n_epoch,
        ... batch_end_callback=mx.callback.Speedometer(1, 10))
        Epoch[0] Batch [10] Speed: 1910.41 samples/sec  Train-accuracy=0.200000
        Epoch[0] Batch [20] Speed: 1764.83 samples/sec  Train-accuracy=0.400000
        Epoch[0] Batch [30] Speed: 1740.59 samples/sec  Train-accuracy=0.500000
        """
        def __init__(self, batch_size, frequent=50, auto_reset=True):
            # Configuration.
            self.batch_size = batch_size
            self.frequent = frequent
            self.auto_reset = auto_reset
            # Timing state: whether the clock is running, its start time, and the
            # last batch index seen (to detect the start of a new epoch).
            self.init = False
            self.tic = 0
            self.last_count = 0

        def __call__(self, param):
            """Callback to Show speed."""
            batch_index = param.nbatch
            # A smaller batch index than last time means a new epoch started.
            if batch_index < self.last_count:
                self.init = False
            self.last_count = batch_index

            if not self.init:
                # Nothing to measure yet: just start the clock.
                self.init = True
                self.tic = time.time()
                return
            if batch_index % self.frequent != 0:
                return

            # Guard against a zero elapsed time (cf. #11504).
            try:
                speed = self.frequent * self.batch_size / (time.time() - self.tic)
            except ZeroDivisionError:
                speed = float('inf')

            tracked = param.eval_metric
            if tracked is None:
                logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                             param.epoch, batch_index, speed)
            else:
                pairs = tracked.get_name_value()
                if self.auto_reset:
                    tracked.reset_local()
                    head = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec'
                    lead = (param.epoch, batch_index - self.frequent, batch_index, speed)
                else:
                    head = 'Epoch[%d] Batch [0-%d]\tSpeed: %.2f samples/sec'
                    lead = (param.epoch, batch_index, speed)
                # One "\t<name>=<value>" slot per metric, filled from the flattened pairs.
                logging.info(head + '\t%s=%f' * len(pairs), *(lead + sum(pairs, ())))
            self.tic = time.time()
    

      

  • 相关阅读:
    [Unity3D]蓝港面试题
    BZOJ 2186 SDOI2008 沙拉公主的困惑 数论
    JSONObject与JSONArray的使用
    一个int类型究竟占多少个字节
    软件开发的金字塔
    poj 1064 Cable master ,二分 精度!!!
    php实现工厂模式
    数据库索引的作用和长处缺点
    C++中使用class和structkeyword的不同
    提交时提示错误This Bundle is invalid.New apps and app updates submitted to the App Store must be built wit
  • 原文地址:https://www.cnblogs.com/cxt-janson/p/13502647.html
Copyright © 2011-2022 走看看