zoukankan      html  css  js  c++  java
  • py-faster R-CNN 用于训练自己的数据(1)

    官方给出的faster R-CNN的源码python版:https://github.com/rbgirshick/py-faster-rcnn

    先来分析一下 整个文件,根目录下的文件

    • caffe-fast-rcnn

    存放caffe框架

    • data

      下面有两个文件夹,第一个是demo,放了5张用于测试的图片。第二个是scripts,里面放了三个脚本文件,分别为下载在VOC2007上训练的Faster R-CNN模型、下载预训练的分类模型(ZF或者VGG16)

    和设置数据集的符号链接的脚本文件。

    • experiments

      该文件下又有三个文件夹,第一个是cfg,用来存放faster r-cnn两种训练方式alt_opt和end2end的配置文件,第二个是scripts,下面有三个脚本,分别为用于训练fast rcnn的脚本文件,用alt_opt方式训练faster rcnn的脚本文件,用end2end方式训练faster rcnn的脚本文件

    logs:训练的日志文件,在experiments/scripts脚本文件中,每一个脚本都会保存一个日志文件到这个目录下

    • lib

      存放了读取数据库的函数以及faster rcnn中的核心代码,rpn中anchor的生成、极大值抑制筛选anchor等等,之后会分析源码。

    • models

      存放了三个分类的网络,ZF、VGG16和VGG_CNN_M_1024

    • tools

      用于训练、测试、压缩fast rcnn网络

    先来看一下datasets下的imbd.py

    imdb的实例数据属性

    def __init__(self, name):
            self._name = name     #数据集的名字
            self._num_classes = 0    #该数据集要识别的物体类别的个数
            self._classes = []    #该数据集的所有类别名称构成的列表
            self._image_index = []    #数据集图片索引列表
            self._obj_proposer = 'selective_search'   #
            self._roidb = None    #词典列表,含有四个key值,分别为boxs的位置,与gt的重合度,相对应gt的类别,是否翻转
            self._roidb_handler = self.default_roidb    #
            # Use this dict for storing dataset specific config options
            self.config = {}  

    roidb的产生与gt_roidb和box_list有关,gt_roidb的获得在pascal_voc.py中的gt_roidb()方法中,而gt_roidb调用了方法_load_pascal_annotation(index)。类pascal_voc继承于imdb。

     gt_roidb = [self._load_pascal_annotation(index)
                        for index in self.image_index]

    _load_pascal_annotation(index)方法从PASCAL VOC的XML文件中读取image和bounding boxes。从这个方法也可以看到roidb中四个字典的value值是怎么设置的。

       def _load_pascal_annotation(self, index):
            filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
            tree = ET.parse(filename)    #将XML按语法生成一棵树
            objs = tree.findall('object')   #寻找所有以object为tag的数据
            if not self.config['use_diff']:   #排除所有标记为difficult的难以识别的目标
                non_diff_objs = [
                    obj for obj in objs if int(obj.find('difficult').text) == 0]
                # if len(non_diff_objs) != len(objs):
                #     print 'Removed {} difficult objects'.format(
                #         len(objs) - len(non_diff_objs))
                objs = non_diff_objs 
            num_objs = len(objs)     #ground-truth的总数
    
            boxes = np.zeros((num_objs, 4), dtype=np.uint16)
            gt_classes = np.zeros((num_objs), dtype=np.int32)
            overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
            # "Seg" area for pascal is just the box area
            seg_areas = np.zeros((num_objs), dtype=np.float32)
    
            # Load object bounding boxes into a data frame.
            for ix, obj in enumerate(objs):
                bbox = obj.find('bndbox')  #左上角和右下角
                x1 = float(bbox.find('xmin').text) - 1   #box的下标是从1开始的,为了建立以0为原点的坐标所以减一
                y1 = float(bbox.find('ymin').text) - 1
                x2 = float(bbox.find('xmax').text) - 1
                y2 = float(bbox.find('ymax').text) - 1
                cls = self._class_to_ind[obj.find('name').text.lower().strip()]   #去除类别名的首尾空格找到该类别的索引
                boxes[ix, :] = [x1, y1, x2, y2]
                gt_classes[ix] = cls
                overlaps[ix, cls] = 1.0     #gt_roibd所在类别的重合率是1,其他类别为0
                seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
    
            overlaps = scipy.sparse.csr_matrix(overlaps)  #将overlaps的稀疏矩阵压缩
    
            return {'boxes' : boxes,
                    'gt_classes': gt_classes,
                    'gt_overlaps' : overlaps,
                    'flipped' : False,
                    'seg_areas' : seg_areas}

    然后就是box_list是每张image中含有的box

    def create_roidb_from_box_list(self, box_list, gt_roidb):
            assert len(box_list) == self.num_images, 
                    'Number of boxes must match number of ground-truth images'
            roidb = []
            for i in xrange(self.num_images):
                boxes = box_list[i]
                num_boxes = boxes.shape[0]    #该image中含有的box个数
                overlaps = np.zeros((num_boxes, self.num_classes), dtype=np.float32)
    
                if gt_roidb is not None and gt_roidb[i]['boxes'].size > 0:
                    gt_boxes = gt_roidb[i]['boxes']
                    gt_classes = gt_roidb[i]['gt_classes']
                    gt_overlaps = bbox_overlaps(boxes.astype(np.float),
                                                gt_boxes.astype(np.float))
                    argmaxes = gt_overlaps.argmax(axis=1)     #获得所有boxes重叠率最高的gt_box的类别索引
                    maxes = gt_overlaps.max(axis=1)    #与argmax对应的重叠率
                    I = np.where(maxes > 0)[0]     #去掉重叠率小于1的box
                    overlaps[I, gt_classes[argmaxes[I]]] = maxes[I]
    
                overlaps = scipy.sparse.csr_matrix(overlaps)   
                roidb.append({
                    'boxes' : boxes,
                    'gt_classes' : np.zeros((num_boxes,), dtype=np.int32),
                    'gt_overlaps' : overlaps,
                    'flipped' : False,
                    'seg_areas' : np.zeros((num_boxes,), dtype=np.float32),
                })
            return roidb

    对roidb进行翻转操作,如果想扩充自己的数据集,可以在这个方法上添加修改

    def append_flipped_images(self):
            num_images = self.num_images
            widths = self._get_widths()
            for i in xrange(num_images):
                boxes = self.roidb[i]['boxes'].copy()
                oldx1 = boxes[:, 0].copy()
                oldx2 = boxes[:, 2].copy()
                boxes[:, 0] = widths[i] - oldx2 - 1
                boxes[:, 2] = widths[i] - oldx1 - 1
                assert (boxes[:, 2] >= boxes[:, 0]).all()
                entry = {'boxes' : boxes,
                         'gt_overlaps' : self.roidb[i]['gt_overlaps'],
                         'gt_classes' : self.roidb[i]['gt_classes'],
                         'flipped' : True}
                self.roidb.append(entry)
            self._image_index = self._image_index * 2

    继续来看pascal_voc.py,这个类是继承于imbd的。

    在这个类中image是通过索引获得的,图片的命名形式如下:  

            self._devkit_path = self._get_default_path() if devkit_path is None 
                                else devkit_path
            self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
         image_path = os.path.join(self._data_path, 'JPEGImages',
                                      index + self._image_ext)

    _year和devkit_path都是初始化实例时传进的字符串,_image_ext属性是图片的后缀名,默认为“.jpg”。

    函数

  • 相关阅读:
    一般索引
    微信小程序扫码
    微信小程序
    PHPStudy设置局域网访问
    phpstudy
    爱番番
    织梦栏目url的seo处理
    织梦dedecms网站迁移搬家图文教程
    打开存储过程中的代码目录(转)
    正在载入
  • 原文地址:https://www.cnblogs.com/catpainter/p/8502768.html
Copyright © 2011-2022 走看看