  • Faster R-CNN code reading: the overall training flow

    II. Training

    Back at line 160 of train.py, training starts with a call to sw.train_model:

    def train_model(self, max_iters):
        """Network training loop."""
        last_snapshot_iter = -1
        timer = Timer()
        model_paths = []
        while self.solver.iter < max_iters:
            # Make one SGD update
            timer.tic()
            self.solver.step(1)
            timer.toc()
            if self.solver.iter % (10 * self.solver_param.display) == 0:
                print 'speed: {:.3f}s / iter'.format(timer.average_time)

            if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = self.solver.iter
                model_paths.append(self.snapshot())

        if last_snapshot_iter != self.solver.iter:
            model_paths.append(self.snapshot())
        return model_paths

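    In this method, self.solver.step(1) performs one SGD update: one forward pass and one backward pass through the network. During the forward pass, data flows from the first layer to the last, where the loss is computed; during the backward pass, the gradient of the loss with respect to each layer's input is computed from the last layer back to the first. The sections below walk through how the Faster R-CNN algorithm runs, layer by layer.

    2.1 The input-data layer

    The first layer is implemented in Python. Its prototxt definition is: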
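    The snapshot schedule of train_model can be checked in isolation with a small sketch. FakeSolver and snapshot_schedule are hypothetical names introduced here for illustration; the fake solver only counts iterations and stands in for caffe.SGDSolver, so no real forward or backward pass is run. The sketch is written for Python 3.

```python
class FakeSolver(object):
    """Hypothetical stand-in for caffe.SGDSolver: it only counts iterations."""

    def __init__(self):
        self.iter = 0

    def step(self, n):
        # The real solver would run n forward/backward passes and weight updates.
        self.iter += n


def snapshot_schedule(max_iters, snapshot_iters):
    """Return the iterations at which the train_model loop would snapshot."""
    solver = FakeSolver()
    last_snapshot_iter = -1
    snapshots = []
    while solver.iter < max_iters:
        solver.step(1)
        if solver.iter % snapshot_iters == 0:
            last_snapshot_iter = solver.iter
            snapshots.append(solver.iter)
    # One final snapshot if the loop did not end exactly on a snapshot iteration.
    if last_snapshot_iter != solver.iter:
        snapshots.append(solver.iter)
    return snapshots


print(snapshot_schedule(max_iters=25, snapshot_iters=10))
print(snapshot_schedule(max_iters=20, snapshot_iters=10))
```

    With max_iters=25 and a snapshot interval of 10, snapshots are taken at iterations 10, 20 and 25; with max_iters=20 the final check adds nothing, because the last iteration was already saved.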

    layer {
      name: 'input-data'
      type: 'Python'
      top: 'data'
      top: 'im_info'
      top: 'gt_boxes'
      python_param {
        module: 'roi_data_layer.layer'
        layer: 'RoIDataLayer'
        param_str: "'num_classes': 2"
      }
    }

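    As the definition shows, the input-data layer has three outputs: data, im_info and gt_boxes. It is implemented by the RoIDataLayer class. The preprocessing done by this layer is as follows: the image is rescaled with its aspect ratio preserved so that the shorter side becomes 600; if the longer side would then exceed 1000, the longer side is scaled to 1000 instead and the shorter side is scaled by the same factor. The three outputs of this layer are:

    1. data: shape (1, 3, h, w) (a batch holds only one image)

    2. im_info: im_info[0], im_info[1] and im_info[2] are h, w and target_size / im_origin_size (the scale factor), respectively

    3. gt_boxes: (x1, y1, x2, y2, cls)

    The functions involved in preprocessing are _get_next_minibatch, get_minibatch, _get_image_blob, prep_im_for_blob and im_list_to_blob.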
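    The rescaling rule described above can be sketched in a few lines. compute_im_scale is a hypothetical helper name used only for this illustration; it mirrors the rule as stated in the text (shorter side to 600, longer side capped at 1000) and computes the scale factor without touching any pixels. The sketch is written for Python 3.

```python
def compute_im_scale(im_h, im_w, target_size=600, max_size=1000):
    """Scale factor for an image of size im_h x im_w (illustrative helper)."""
    size_min = min(im_h, im_w)
    size_max = max(im_h, im_w)
    # First try to bring the shorter side to target_size.
    im_scale = float(target_size) / float(size_min)
    # If the longer side would then exceed max_size, cap the longer side instead.
    if round(im_scale * size_max) > max_size:
        im_scale = float(max_size) / float(size_max)
    return im_scale


for h, w in [(375, 500), (300, 900)]:
    s = compute_im_scale(h, w)
    print(h, w, round(s, 3), int(round(h * s)), int(round(w * s)))
```

    A 375 x 500 image is scaled by 1.6 to 600 x 800. A 300 x 900 image would become 600 x 1800 under the first rule, so the cap applies and it ends up at roughly 333 x 1000.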

    While the network is being constructed (that is, in self.solver = caffe.SGDSolver(solver_prototxt)), the setup method of this class is called:

    # Config values (cfg.TRAIN.*) used by this layer
    __C.TRAIN.IMS_PER_BATCH = 1
    __C.TRAIN.SCALES = [600]
    __C.TRAIN.MAX_SIZE = 1000
    __C.TRAIN.HAS_RPN = True
    __C.TRAIN.BBOX_REG = True

    # RoIDataLayer.setup
    def setup(self, bottom, top):
        """Setup the RoIDataLayer."""

        # parse the layer parameter string, which must be valid YAML
        layer_params = yaml.load(self.param_str_)

        self._num_classes = layer_params['num_classes']

        self._name_to_top_map = {}

        # data blob: holds a batch of N images, each with 3 channels
        idx = 0
        top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 3,
            max(cfg.TRAIN.SCALES), cfg.TRAIN.MAX_SIZE)
        self._name_to_top_map['data'] = idx
        idx += 1

        if cfg.TRAIN.HAS_RPN:
            top[idx].reshape(1, 3)
            self._name_to_top_map['im_info'] = idx
            idx += 1

            top[idx].reshape(1, 4)
            self._name_to_top_map['gt_boxes'] = idx
            idx += 1
        else: # not using RPN
            # rois blob: holds R regions of interest, each is a 5-tuple
            # (n, x1, y1, x2, y2) specifying an image batch index n and a
            # rectangle (x1, y1, x2, y2)
            top[idx].reshape(1, 5)
            self._name_to_top_map['rois'] = idx
            idx += 1

            # labels blob: R categorical labels in [0, ..., K] for K foreground
            # classes plus background
            top[idx].reshape(1)
            self._name_to_top_map['labels'] = idx
            idx += 1

            if cfg.TRAIN.BBOX_REG:
                # bbox_targets blob: R bounding-box regression targets with 4
                # targets per class
                top[idx].reshape(1, self._num_classes * 4)
                self._name_to_top_map['bbox_targets'] = idx
                idx += 1

                # bbox_inside_weights blob: At most 4 targets per roi are active;
                # this binary vector specifies the subset of active targets
                top[idx].reshape(1, self._num_classes * 4)
                self._name_to_top_map['bbox_inside_weights'] = idx
                idx += 1

                top[idx].reshape(1, self._num_classes * 4)
                self._name_to_top_map['bbox_outside_weights'] = idx
                idx += 1

        print 'RoiDataLayer: name_to_top:', self._name_to_top_map
        assert len(top) == len(self._name_to_top_map)

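    The method mainly defines the shapes of the outputs. Note that during the forward pass the shape of each top is defined again, and the shapes set at that point usually differ from the ones set in setup.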
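    The difference between the two sets of shapes can be made concrete with plain tuples. This is only a sketch: setup_shapes and forward_shapes are hypothetical helpers, the placeholder shapes are copied from the setup code above, and the forward shapes are an assumption derived from the output descriptions in section 2.1 (one image per batch, one 5-value row per ground-truth box). The sketch is written for Python 3.

```python
def setup_shapes(ims_per_batch=1, scales=(600,), max_size=1000):
    """Placeholder top shapes reserved by setup() when HAS_RPN is True."""
    return {
        'data': (ims_per_batch, 3, max(scales), max_size),
        'im_info': (1, 3),
        'gt_boxes': (1, 4),
    }


def forward_shapes(scaled_h, scaled_w, num_gt):
    """Assumed top shapes during forward() for one rescaled image."""
    return {
        'data': (1, 3, scaled_h, scaled_w),  # N, C, H, W
        'im_info': (1, 3),                   # h, w, scale
        'gt_boxes': (num_gt, 5),             # x1, y1, x2, y2, cls
    }


placeholder = setup_shapes()
actual = forward_shapes(scaled_h=600, scaled_w=800, num_gt=3)
for name in ('data', 'im_info', 'gt_boxes'):
    print(name, placeholder[name], actual[name], placeholder[name] == actual[name])
```

    For a 600 x 800 input with three ground-truth boxes, only im_info keeps its placeholder shape; data and gt_boxes are both resized.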

  • Original article: https://www.cnblogs.com/pursuiting/p/10129049.html