  • py-faster-rcnn code reading (1): train_net.py & train.py

    # train_net.py

    #!/usr/bin/env python
    # --------------------------------------------------------
    # Fast R-CNN
    # Copyright (c) 2015 Microsoft
    # Licensed under The MIT License [see LICENSE for details]
    # Written by Ross Girshick
    # --------------------------------------------------------
    """Train a Fast R-CNN network on a region of interest database."""

    import _init_paths
    from fast_rcnn.train import get_training_roidb, train_net
    from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
    from datasets.factory import get_imdb
    import datasets.imdb
    import caffe
    import argparse
    import pprint
    import numpy as np
    import sys

    def parse_args():
        # Command-line arguments parsed at run time
        """
        Parse input arguments
        """
        parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
        parser.add_argument('--gpu', dest='gpu_id',
                            help='GPU device id to use [0]',
                            default=0, type=int)
        parser.add_argument('--solver', dest='solver',
                            help='solver prototxt',
                            default=None, type=str)
        parser.add_argument('--iters', dest='max_iters',
                            help='number of iterations to train',
                            default=40000, type=int)
        parser.add_argument('--weights', dest='pretrained_model',
                            help='initialize with pretrained model weights',
                            default=None, type=str)
        parser.add_argument('--cfg', dest='cfg_file',
                            help='optional config file',
                            default=None, type=str)
        parser.add_argument('--imdb', dest='imdb_name',
                            help='dataset to train on',
                            default='voc_2007_trainval', type=str)
        parser.add_argument('--rand', dest='randomize',
                            help='randomize (do not use a fixed seed)',
                            action='store_true')
        parser.add_argument('--set', dest='set_cfgs',
                            help='set config keys', default=None,
                            nargs=argparse.REMAINDER)

        if len(sys.argv) == 1:
            parser.print_help()
            sys.exit(1)

        args = parser.parse_args()
        return args

    def combined_roidb(imdb_names):
        # Merge roidbs: each roidb comes from one dataset, and an experiment may
        # use several datasets, so the roidbs of multiple datasets must be combined
        def get_roidb(imdb_name):
            imdb = get_imdb(imdb_name)
            print 'Loaded dataset `{:s}` for training'.format(imdb.name)
            # Set the proposal method, here selective search (see config.py)
            imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
            print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
            # Build the roidb used for training (defined in train.py): horizontally
            # flipped copies are appended and some descriptive attributes are added
            roidb = get_training_roidb(imdb)
            return roidb

        roidbs = [get_roidb(s) for s in imdb_names.split('+')]
        roidb = roidbs[0]
        # Combine the roidbs here
        if len(roidbs) > 1:
            for r in roidbs[1:]:
                roidb.extend(r)
            imdb = datasets.imdb.imdb(imdb_names)
        else:
            # get_imdb (defined in datasets/factory.py) returns an imdb by name
            imdb = get_imdb(imdb_names)
        return imdb, roidb

    if __name__ == '__main__':
        args = parse_args()

        print('Called with args:')
        print(args)

        # Defined in config.py; see "py-faster-rcnn code reading 2: config.py"
        if args.cfg_file is not None:
            cfg_from_file(args.cfg_file)
        if args.set_cfgs is not None:
            cfg_from_list(args.set_cfgs)

        cfg.GPU_ID = args.gpu_id

        print('Using config:')
        pprint.pprint(cfg)

        if not args.randomize:
            # fix the random seeds (numpy and caffe) for reproducibility
            np.random.seed(cfg.RNG_SEED)
            caffe.set_random_seed(cfg.RNG_SEED)

        # set up caffe
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)

        imdb, roidb = combined_roidb(args.imdb_name)
        print '{:d} roidb entries'.format(len(roidb))

        output_dir = get_output_dir(imdb)
        print 'Output will be saved to `{:s}`'.format(output_dir)

        # Defined in train.py
        train_net(args.solver, roidb, output_dir,
                  pretrained_model=args.pretrained_model,
                  max_iters=args.max_iters)
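
    # Illustration (not from train_net.py): the --imdb argument accepts several
    # dataset names joined with '+', e.g. 'voc_2007_trainval+voc_2012_trainval';
    # combined_roidb() splits on '+' and concatenates the per-dataset roidbs.
    # A minimal sketch using toy stand-in lists instead of real roidb entries:
    imdb_names = 'voc_2007_trainval+voc_2012_trainval'
    roidbs = [[{'dataset': name, 'index': i} for i in range(2)]
              for name in imdb_names.split('+')]
    roidb = roidbs[0]
    for r in roidbs[1:]:
        roidb.extend(r)            # same extend() loop as in combined_roidb
    print(len(roidb))              # 4 -> entries from both datasets
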
    # train.py
    
    """Train a Fast R-CNN network."""
    
    import caffe
    from fast_rcnn.config import cfg
    import roi_data_layer.roidb as rdl_roidb
    from utils.timer import Timer
    import numpy as np
    import os
    
    from caffe.proto import caffe_pb2
    import google.protobuf as pb2
    
    class SolverWrapper(object):
        # A simple wrapper around Caffe's solver; it lets us control the snapshot
        # process, which is used to unnormalize the learned bbox regression weights
        """A simple wrapper around Caffe's solver.
        This wrapper gives us control over the snapshotting process, which we
        use to unnormalize the learned bounding-box regression weights.
        """
    
        def __init__(self, solver_prototxt, roidb, output_dir,
                     pretrained_model=None):
            """Initialize the SolverWrapper."""
            self.output_dir = output_dir
    
            if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
                cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
                # RPN can only use precomputed normalization because there are no
                # fixed statistics to compute a priori
    
                assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
            # Compute the bounding-box regression targets; returns the means and stds of the targets
            if cfg.TRAIN.BBOX_REG:
                print 'Computing bounding-box regression targets...'
                self.bbox_means, self.bbox_stds = \
                        rdl_roidb.add_bbox_regression_targets(roidb)
                print 'done'
            # Create the solver, copy the pretrained model weights, and set the roidb
            self.solver = caffe.SGDSolver(solver_prototxt)
            if pretrained_model is not None:
                print ('Loading pretrained model '
                       'weights from {:s}').format(pretrained_model)
                self.solver.net.copy_from(pretrained_model)
    
            self.solver_param = caffe_pb2.SolverParameter()
            with open(solver_prototxt, 'rt') as f:
                pb2.text_format.Merge(f.read(), self.solver_param)
    
            self.solver.net.layers[0].set_roidb(roidb)
    
        def snapshot(self):
            """Take a snapshot of the network after unnormalizing the learned
            bounding-box regression weights. This enables easy use at test-time.
            """
            # Unnormalize the bbox prediction parameters before saving the snapshot
            net = self.solver.net
    
            scale_bbox_params = (cfg.TRAIN.BBOX_REG and
                                 cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                                 net.params.has_key('bbox_pred'))
    
            if scale_bbox_params:
                # save original values
                orig_0 = net.params['bbox_pred'][0].data.copy()
                orig_1 = net.params['bbox_pred'][1].data.copy()
    
                # scale and shift with bbox reg unnormalization; then save snapshot
                # Undo the normalization: multiply by the stds and add the means
                net.params['bbox_pred'][0].data[...] = \
                        (net.params['bbox_pred'][0].data *
                         self.bbox_stds[:, np.newaxis])
                net.params['bbox_pred'][1].data[...] = \
                        (net.params['bbox_pred'][1].data *
                         self.bbox_stds + self.bbox_means)
            # Build the snapshot filename
            infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                     if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
            filename = (self.solver_param.snapshot_prefix + infix +
                        '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
            filename = os.path.join(self.output_dir, filename)
            # Save the snapshot
            net.save(str(filename))
            print 'Wrote snapshot to: {:s}'.format(filename)
    
            # Only the saved snapshot's bbox parameters are unnormalized (for test-time
            # use); training must continue with normalized weights, so restore them below
            if scale_bbox_params:
                # restore net to original state
                net.params['bbox_pred'][0].data[...] = orig_0
                net.params['bbox_pred'][1].data[...] = orig_1
            return filename
    
        def train_model(self, max_iters):
            """Network training loop."""
            last_snapshot_iter = -1
            timer = Timer()
            model_paths = []
            while self.solver.iter < max_iters:
                # Make one SGD update
                timer.tic()
                self.solver.step(1)
                timer.toc()
                if self.solver.iter % (10 * self.solver_param.display) == 0:
                    print 'speed: {:.3f}s / iter'.format(timer.average_time)
                # Save a snapshot every SNAPSHOT_ITERS iterations
                if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                    last_snapshot_iter = self.solver.iter
                    model_paths.append(self.snapshot())
            # Save a final snapshot after the last iteration as well
            if last_snapshot_iter != self.solver.iter:
                model_paths.append(self.snapshot())
            return model_paths
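
    # Illustration (not from train.py): a minimal numpy sketch, with made-up shapes,
    # of why the weight scaling in snapshot() works. bbox_pred is trained to predict
    # normalized targets t_norm = (t - mean) / std; scaling its weights by the stds
    # and shifting its biases yields a layer that outputs the unnormalized targets t
    # directly, which is what we want at test time.
    import numpy as np
    rng = np.random.RandomState(0)
    W = rng.randn(8, 16)                  # stands in for net.params['bbox_pred'][0], shape (4*K, D)
    b = rng.randn(8)                      # stands in for net.params['bbox_pred'][1]
    stds = rng.rand(8) + 0.1              # stands in for self.bbox_stds
    means = rng.randn(8)                  # stands in for self.bbox_means
    x = rng.randn(16)                     # one RoI feature vector
    t = (W.dot(x) + b) * stds + means     # unnormalized prediction we want
    W_un = W * stds[:, np.newaxis]        # the in-place scaling done in snapshot()
    b_un = b * stds + means
    assert np.allclose(W_un.dot(x) + b_un, t)
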
    
    def get_training_roidb(imdb):
        """Returns a roidb (Region of Interest database) for use in training."""
        if cfg.TRAIN.USE_FLIPPED:
            print 'Appending horizontally-flipped training examples...'
            # Horizontally flip the images to enlarge the roidb
            imdb.append_flipped_images()
            print 'done'
    
        print 'Preparing training data...'
        # Add some descriptive attributes (max_overlaps, max_classes, ...) to the original roidb entries
        rdl_roidb.prepare_roidb(imdb)
        print 'done'
    
        return imdb.roidb
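
    # Illustration (not from train.py): a toy sketch of the attributes prepare_roidb
    # derives for each entry from its gt_overlaps matrix (values made up). Each row
    # is one box, each column one class; max_overlaps / max_classes record the
    # best-matching class and its overlap, and are what filter_roidb() below uses.
    import numpy as np
    gt_overlaps = np.array([[0.0, 0.9, 0.1],    # box 0: best class 1, overlap 0.9
                            [0.0, 0.2, 0.4]])   # box 1: best class 2, overlap 0.4
    max_overlaps = gt_overlaps.max(axis=1)      # -> entry['max_overlaps']
    max_classes = gt_overlaps.argmax(axis=1)    # -> entry['max_classes']
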
    
    def filter_roidb(roidb):
        """Remove roidb entries that have no usable RoIs."""
        # Drop entries without usable RoIs; a valid image must have at least one
        # foreground RoI or at least one background RoI
        def is_valid(entry):
            # Valid images have:
            #   (1) At least one foreground RoI OR
            #   (2) At least one background RoI
            overlaps = entry['max_overlaps']
            # find boxes with sufficient overlap
            fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
            # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
            bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                               (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
            # image is only valid if such boxes exist
            valid = len(fg_inds) > 0 or len(bg_inds) > 0
            return valid
    
        num = len(roidb)
        filtered_roidb = [entry for entry in roidb if is_valid(entry)]
        num_after = len(filtered_roidb)
        print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                           num, num_after)
        return filtered_roidb
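
    # Illustration (not from train.py): a toy entry run through the same checks as
    # is_valid(), using the default thresholds from config.py (FG_THRESH = 0.5,
    # BG_THRESH_HI = 0.5, BG_THRESH_LO = 0.1); these values are assumptions here,
    # your config may differ.
    import numpy as np
    entry = {'max_overlaps': np.array([0.8, 0.2, 0.05])}
    overlaps = entry['max_overlaps']
    fg_inds = np.where(overlaps >= 0.5)[0]                       # box 0 is foreground
    bg_inds = np.where((overlaps < 0.5) & (overlaps >= 0.1))[0]  # box 1 is background
    print(len(fg_inds) > 0 or len(bg_inds) > 0)                  # True -> entry is kept
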
    
    def train_net(solver_prototxt, roidb, output_dir,
                  pretrained_model=None, max_iters=40000):
        """Train a Fast R-CNN network."""
        # Train Fast R-CNN
        roidb = filter_roidb(roidb)
        sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                           pretrained_model=pretrained_model)
    
        print 'Solving...'
        # Run the training loop
        model_paths = sw.train_model(max_iters)
        print 'done solving'
        return model_paths
  • Original article: https://www.cnblogs.com/alanma/p/6802835.html