zoukankan      html  css  js  c++  java
  • mxnet module.fit源码解析

     module.fit源码分析

    首先来到module模块中,即https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/module,进入base_module.py中,我们便可以看到fit()的原型。

    class BaseModule(object):
        """Excerpt of MXNet's ``BaseModule`` showing the high-level training API.

        Only ``forward_backward``, ``score`` and ``fit`` are reproduced here. The
        primitives they call (``bind``, ``init_params``, ``init_optimizer``,
        ``prepare``, ``forward``, ``backward``, ``update``, ``update_metric``,
        ``get_params``, ``set_params``, ``install_monitor``) are defined elsewhere.
        """
        ################################################################################
        # High Level API
        ################################################################################
        def forward_backward(self, data_batch):
            """A convenient function that calls both ``forward`` and ``backward``."""
            self.forward(data_batch, is_train=True)
            self.backward()

        # Evaluation on a validation set.
        def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
                  score_end_callback=None,
                  reset=True, epoch=0, sparse_row_id_fn=None):
            """Runs prediction on ``eval_data`` and evaluates the performance according to
            the given ``eval_metric``.

            Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
            to see an end-to-end use-case.

            Parameters
            ----------
            eval_data : DataIter
                Evaluation data to run prediction on.
            eval_metric : EvalMetric or list of EvalMetrics
                Evaluation metric to use. Anything that is not an ``EvalMetric``
                instance is passed to ``metric.create``.
            num_batch : int
                Number of batches to run. Defaults to ``None``, indicating run until the `DataIter`
                finishes.
            batch_end_callback : function
                Could also be a list of functions. Called after every evaluated batch
                with a ``BatchEndParam``.
            score_end_callback : function or list of functions
                Called once after evaluation finishes, with a ``BatchEndParam`` whose
                ``nbatch`` is the number of batches actually evaluated.
            reset : bool
                Defaults to ``True``. Indicates whether we should reset `eval_data` before starting
                evaluating.
            epoch : int
                Defaults to 0. For compatibility, this will be passed to callbacks (if any).
                During training, this will correspond to the training epoch number.
            sparse_row_id_fn : A callback function
                The function  takes `data_batch` as an input and returns a dict of
                str -> NDArray. The resulting dict is used for pulling row_sparse
                parameters from the kvstore, where the str key is the name of the param,
                and the value is the row id of the param to pull.

            Returns
            -------
            list of (name, value) pairs
                The result of ``eval_metric.get_name_value()`` after evaluation.

            Examples
            --------
            >>> # An example of using score for prediction.
            >>> # Evaluate accuracy on val_dataiter
            >>> metric = mx.metric.Accuracy()
            >>> mod.score(val_dataiter, metric)
            >>> mod.score(val_dataiter, ['mse', 'acc'])
            """
            assert self.binded and self.params_initialized

            # Rewind the validation iterator unless the caller opted out.
            if reset:
                eval_data.reset()

            if not isinstance(eval_metric, metric.EvalMetric):
                eval_metric = metric.create(eval_metric)

            eval_metric.reset()
            actual_num_batch = 0

            for nbatch, eval_batch in enumerate(eval_data):
                if num_batch is not None and nbatch == num_batch:
                    break
                # Let the module prepare for this batch (e.g. pull row_sparse params).
                self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
                # Inference-only forward pass.
                self.forward(eval_batch, is_train=False)
                # Accumulate metrics; a list of batches is passed as pre-sliced labels.
                if isinstance(eval_batch, list):
                    self.update_metric(eval_metric, [eb.label for eb in eval_batch], pre_sliced=True)
                else:
                    self.update_metric(eval_metric, eval_batch.label)

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                actual_num_batch += 1

            # Whole-evaluation callback.
            if score_end_callback:
                params = BatchEndParam(epoch=epoch,
                                       nbatch=actual_num_batch,
                                       eval_metric=eval_metric,
                                       locals=locals())
                for callback in _as_list(score_end_callback):
                    callback(params)

            # List of (metric name, value) pairs.
            return eval_metric.get_name_value()

        def fit(self, train_data, eval_data=None, eval_metric='acc',
                epoch_end_callback=None, batch_end_callback=None, kvstore='local',
                optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
                eval_end_callback=None,
                eval_batch_end_callback=None, initializer=Uniform(0.01),
                arg_params=None, aux_params=None, allow_missing=False,
                force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
                validation_metric=None, monitor=None, sparse_row_id_fn=None):
            """Trains the module parameters.

            Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
            to see an end-to-end use-case.

            Parameters
            ----------
            train_data : DataIter
                Training data iterator. It must yield at least one batch per epoch.
            eval_data : DataIter
                If not ``None``, used as the validation set. Performance on it is
                evaluated after every epoch.
            eval_metric : str or EvalMetric
                Defaults to ``'acc'`` (accuracy). Performance metric reported during
                training. Other predefined metrics include 'ce' (CrossEntropy), 'f1',
                'mae', 'mse', 'rmse' and 'top_k_accuracy'.
            epoch_end_callback : function or list of functions
                Called at the end of every epoch with ``epoch``, ``symbol``,
                ``arg_params`` and ``aux_params``.
            batch_end_callback : function or list of functions
                Called at the end of every training batch with a ``BatchEndParam``.
            kvstore : str or KVStore
                Key-value store used for parameter updates. Defaults to ``'local'``.
                Examples from the original annotations: ``'local'`` (update on CPU),
                ``'device'`` (gradients and weight updates on GPU) and
                ``'dist_device_sync'`` (distributed training).
            optimizer : str or Optimizer
                Defaults to ``'sgd'``.
            optimizer_params : dict or tuple of (str, value) pairs
                Optimizer hyper-parameters. Defaults to ``(('learning_rate', 0.01),)``.
            eval_end_callback : function or list of functions
                Called once after each full validation pass. It is forwarded to
                ``score`` as ``score_end_callback``.
            eval_batch_end_callback : function or list of functions
                Called after each validation batch. It is forwarded to ``score`` as
                ``batch_end_callback``.
            initializer : Initializer
                Used to initialize module parameters that are not yet initialized.
            arg_params : dict
                Defaults to ``None``. If given, these values are used instead of
                ``initializer`` for the corresponding arguments.
            aux_params : dict
                Defaults to ``None``. Same as ``arg_params``, but for auxiliary states.
            allow_missing : bool
                Defaults to ``False``. Whether parameters may be missing from
                ``arg_params``/``aux_params`` when those are given. If ``True``,
                missing ones are initialized with ``initializer``.
            force_rebind : bool
                Defaults to ``False``. Whether to force rebinding the executors if they
                are already bound.
            force_init : bool
                Defaults to ``False``. Whether to force re-initializing parameters even
                if they are already initialized.
            begin_epoch : int
                Defaults to 0. Index of the first epoch. When resuming from a checkpoint
                saved at epoch ``n`` of a previous run, this is usually ``n+1``.
            num_epoch : int
                Required. Training runs epochs ``begin_epoch`` .. ``num_epoch - 1``.
            validation_metric : str or EvalMetric
                Metric used on ``eval_data``. Defaults to ``None``, which means the same
                specification as ``eval_metric``.
            monitor : Monitor
                If not ``None``, it is installed on the module and ticked around every
                training batch.
            sparse_row_id_fn : A callback function
                The function  takes `data_batch` as an input and returns a dict of
                str -> NDArray. The resulting dict is used for pulling row_sparse
                parameters from the kvstore, where the str key is the name of the param,
                and the value is the row id of the param to pull.

            Raises
            ------
            AssertionError
                If ``num_epoch`` is ``None``.
            ValueError
                If ``train_data`` yields no batch at the start of an epoch.

            Examples
            --------
            >>> # An example of using fit for training.
            >>> # Assume training dataIter and validation dataIter are ready
            >>> # Assume loading a previously checkpointed model
            >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
            >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
            ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
            ...     arg_params=arg_params, aux_params=aux_params,
            ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
            """
            assert num_epoch is not None, 'please specify number of epochs'

            # Bind the executors to the training data/label shapes.
            self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                      for_training=True, force_rebind=force_rebind)
            if monitor is not None:
                self.install_monitor(monitor)
            # Initialize parameters (see arg_params / aux_params / allow_missing above).
            self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                             allow_missing=allow_missing, force_init=force_init)
            self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                                optimizer_params=optimizer_params)

            # Validation reuses the training metric specification unless told otherwise.
            # This is captured before the conversion below, so a string spec makes
            # ``score`` create its own metric object.
            if validation_metric is None:
                validation_metric = eval_metric
            if not isinstance(eval_metric, metric.EvalMetric):
                eval_metric = metric.create(eval_metric)

            ################################################################################
            # training loop
            ################################################################################
            for epoch in range(begin_epoch, num_epoch):
                tic = time.time()
                eval_metric.reset()
                nbatch = 0
                data_iter = iter(train_data)
                end_of_batch = False
                # Fetch the first batch explicitly. An empty iterator would otherwise
                # leak a bare StopIteration out of fit().
                try:
                    next_data_batch = next(data_iter)
                except StopIteration:
                    raise ValueError('train_data yielded no batches in epoch %d; make sure '
                                     'it is not empty and has been reset' % epoch)
                # Prepare the first batch exactly like every prefetched batch below (and
                # like ``score`` does), so sparse_row_id_fn is honoured for it too.
                self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
                while not end_of_batch:
                    data_batch = next_data_batch
                    if monitor is not None:
                        monitor.tic()
                    # Forward + backward pass to compute gradients, then apply the optimizer.
                    self.forward_backward(data_batch)
                    self.update()

                    # Accumulate training metrics; a list of batches is passed pre-sliced.
                    if isinstance(data_batch, list):
                        self.update_metric(eval_metric,
                                           [db.label for db in data_batch],
                                           pre_sliced=True)
                    else:
                        self.update_metric(eval_metric, data_batch.label)

                    # Prefetch and prepare the next batch; exhaustion ends the epoch.
                    try:
                        next_data_batch = next(data_iter)
                        self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
                    except StopIteration:
                        end_of_batch = True

                    if monitor is not None:
                        monitor.toc_print()

                    # Snapshot epoch-level metric values on the last batch. They are
                    # logged after the loop and are also visible to batch_end_callback
                    # through ``locals()``.
                    if end_of_batch:
                        eval_name_vals = eval_metric.get_global_name_value()

                    if batch_end_callback is not None:
                        batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                         eval_metric=eval_metric,
                                                         locals=locals())
                        for callback in _as_list(batch_end_callback):
                            callback(batch_end_params)
                    nbatch += 1

                # One epoch of training is finished: log "Train-<name>=<value>" per metric.
                for name, val in eval_name_vals:
                    self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
                toc = time.time()
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

                # sync aux params across devices
                arg_params, aux_params = self.get_params()
                self.set_params(arg_params, aux_params)

                if epoch_end_callback is not None:
                    for callback in _as_list(epoch_end_callback):
                        callback(epoch, self.symbol, arg_params, aux_params)

                #----------------------------------------
                # evaluation on validation set
                if eval_data is not None:
                    res = self.score(eval_data, validation_metric,
                                     score_end_callback=eval_end_callback,
                                     batch_end_callback=eval_batch_end_callback, epoch=epoch)
                    #TODO: pull this into default
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

                # end of 1 epoch, reset the data-iter for another epoch
                train_data.reset()
    

     

    训练log源码分析

    _cb = mx.callback.Speedometer(batch_size=batch_size, frequent=frequent)

    def _batch_callback(param):
        """Batch-end callback that hands ``param`` to the Speedometer.

        The Speedometer emits the periodic training log line, for example
        ``INFO:root:Epoch[26] Batch [0-20] Speed: 257.26 samples/sec acc=0.968571 lossvalue=0.331392``.
        """
        _cb(param)
    class Speedometer(object):
        """Batch-end callback that periodically logs throughput and metric values.

        Parameters
        ----------
        batch_size: int
            Number of samples per batch; used to turn batches/second into samples/second.
        frequent: int
            Log once every ``frequent`` batches (50 by default).
        auto_reset : bool
            If true, call ``reset_local()`` on the metric after each log line, so every
            line reports only the most recent window of batches.

        Example
        -------
        >>> # Log speed and metrics every ten batches with a batch size of one.
        >>> module.fit(iterator, num_epoch=n_epoch,
        ... batch_end_callback=mx.callback.Speedometer(1, 10))

        With ``auto_reset=True`` each line looks like
        ``Epoch[<epoch>] Batch [<start>-<end>]<TAB>Speed: <x> samples/sec<TAB><name>=<value>...``.
        """
        def __init__(self, batch_size, frequent=50, auto_reset=True):
            # Configuration.
            self.batch_size = batch_size
            self.frequent = frequent
            self.auto_reset = auto_reset
            # Timing state: ``init`` is False until a starting timestamp is recorded.
            self.init = False
            self.tic = 0
            self.last_count = 0

        def __call__(self, param):
            """Callback to Show speed."""
            batch_idx = param.nbatch

            # A batch counter lower than the previous one means a new epoch began:
            # restart timing instead of reporting an interval that spans epochs.
            if batch_idx < self.last_count:
                self.init = False
            self.last_count = batch_idx

            if not self.init:
                # First call (of the run or of a new epoch): only start the clock.
                self.init = True
                self.tic = time.time()
                return

            if batch_idx % self.frequent != 0:
                return

            # Samples processed per second over the last ``frequent`` batches.
            elapsed = time.time() - self.tic
            try:
                speed = self.frequent * self.batch_size / elapsed
            except ZeroDivisionError:
                speed = float('inf')

            metric_obj = param.eval_metric
            if metric_obj is None:
                logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                             param.epoch, batch_idx, speed)
            else:
                # Read the values before a possible reset, then flatten the
                # (name, value) pairs into the positional logging arguments.
                pairs = metric_obj.get_name_value()
                flat_args = sum(pairs, ())
                metric_fields = '\t%s=%f' * len(pairs)
                if self.auto_reset:
                    metric_obj.reset_local()
                    logging.info('Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec' + metric_fields,
                                 param.epoch, batch_idx - self.frequent, batch_idx, speed,
                                 *flat_args)
                else:
                    logging.info('Epoch[%d] Batch [0-%d]\tSpeed: %.2f samples/sec' + metric_fields,
                                 param.epoch, batch_idx, speed, *flat_args)
            self.tic = time.time()
    

      

  • 相关阅读:
    codeforces C. Cows and Sequence 解题报告
    codeforces A. Point on Spiral 解题报告
    codeforces C. New Year Ratings Change 解题报告
    codeforces A. Fox and Box Accumulation 解题报告
    codeforces B. Multitasking 解题报告
    git命令使用
    shell简单使用
    知识束缚
    php 调用系统命令
    数据传输方式(前端与后台 ,后台与后台)
  • 原文地址:https://www.cnblogs.com/cxt-janson/p/13502647.html
Copyright © 2011-2022 走看看