zoukankan      html  css  js  c++  java
  • pytorch faster_rcnn

    代码地址:https://github.com/jwyang/faster-rcnn.pytorch

    1.fasterRCNN.train():这个不是让网络进行训练,而是让module in training mode,有些module在traing model和testing model下不同,比如bn

    即self.training这个成员变量为true(这个成员变量属于nn.Module,fasterRCNN继承了这个成员变量),以下是train成员函数的源码

    2.bn的train和test不同,train的时候应该是要学习参数的,test的时候关闭,pytorch的用法如下:

    pytorch的batchnorm使用时需要小心,training和track_running_stats可以组合出三种behavior,很容易掉坑里(我刚发现我对track_running_stats的理解错了)。

    1. training=True, track_running_stats=True, 这是常用的training时期待的行为,running_mean 和running_var会跟踪不同batch数据的mean和variance。
    2. training=True, track_running_stats=False, 这时候batchnorm不跟踪跨batch数据的statistics了,而是用每个batch的mean和variance做normalization。
    3. training=False, track_running_stats=True, 这是我们期待的test时候的行为,即使用training阶段估计的running_mean 和running_var.
    4. training=False, track_running_stats=False,同2(!!!).
    https://www.zhihu.com/question/282672547/answer/529154567李韶华的回答
    3.class_agnostic == true就是所有类别回归同一个坐标,也就是一个框回归一个坐标
            == false是每个类别单独回归4个坐标
        if self.class_agnostic:
          self.RCNN_bbox_pred = nn.Linear(4096, 4)
        else:
          self.RCNN_bbox_pred = nn.Linear(4096, 4 * self.n_classes)
    4.真正开始训练的代码不是fasterRCNN.train(),而是下面这段代码:
          rois, cls_prob, bbox_pred, 
          rpn_loss_cls, rpn_loss_box, 
          RCNN_loss_cls, RCNN_loss_bbox, 
          rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)

    fasterRCNN是一个实例,应该是没办法进行调用的,但实际上这段代码执行的是forward函数。为什么?其实就是python的括号重载。fasterRCNN这个实例继承于nn.Module类,这个类定义了forward成员函数,nn.Module类使用了__call__进行了重载,让实例能够调用,并且调用的函数是forward函数,具体代码见下面的源码:

    python中__call__函数的作用是使实例能够像函数一样被调用https://blog.csdn.net/Yaokai_AssultMaster/article/details/70256621,也称之为括号重载,即‘()’

        def __call__(self, *input, **kwargs):
            for hook in self._forward_pre_hooks.values():
                hook(self, input)
            if torch._C._get_tracing_state():
                result = self._slow_forward(*input, **kwargs)
            else:
                result = self.forward(*input, **kwargs)
            for hook in self._forward_hooks.values():
                hook_result = hook(self, input, result)
                if hook_result is not None:
                    raise RuntimeError(
                        "forward hooks should never return any values, but '{}'"
                        "didn't return None".format(hook))
            if len(self._backward_hooks) > 0:
                var = result
                while not isinstance(var, torch.Tensor):
                    if isinstance(var, dict):
                        var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                    else:
                        var = var[0]
                grad_fn = var.grad_fn
                if grad_fn is not None:
                    for hook in self._backward_hooks.values():
                        wrapper = functools.partial(hook, self)
                        functools.update_wrapper(wrapper, hook)
                        grad_fn.register_hook(wrapper)
            return result

    nn.Module定义了一个forward的成员函数,这个函数在基类中没有实现,而是在各个子类自己实现的,每个子类都必须实现forward函数:

        def forward(self, *input):
            r"""Defines the computation performed at every call.
            Should be overridden by all subclasses.
            .. note::
                Although the recipe for forward pass needs to be defined within
                this function, one should call the :class:`Module` instance afterwards
                instead of this since the former takes care of running the
                registered hooks while the latter silently ignores them.
            """
            raise NotImplementedError

    子类调用forward函数不能直接用calss.forward(),而是用实例的函数调用,具体的原因好像是hook,这个在上面__call__函数中也看到调用forward使用了跟hook有关的input

     

  • 相关阅读:
    Northwind测试学习用数据库
    DataTables在回调方法中使用api
    DataTables获取表单输入框数据
    DataTables选择行并删除(删除单行)
    DataTables选择多行
    DataTables给每一列添加下拉框搜索
    AngularJS 父子控制器
    更改AngularJS的语法解析符号
    AngularJS中的控制器示例_3
    AngularJS中的控制器示例_2
  • 原文地址:https://www.cnblogs.com/ymjyqsx/p/10084507.html
Copyright © 2011-2022 走看看