  • Training your own classifier on faster-rcnn with ImageNet

    The full code is at https://github.com/zhiyishou/py-faster-rcnn


    This is the recognition result after I trained on cup and glasses.

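    faster-rcnn builds on fast-rcnn by adding an RPN (region proposal network), which keeps the whole training pipeline on the GPU and so improves efficiency. Here we use the ImageNet dataset to train our own classifier on faster-rcnn. ImageNet offers images and bounding box annotations for many categories. Within each category the number of annotations is less than or equal to the number of images, so we build the index from the annotations.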

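    lib/datasets/factory.py provides the loaders for the coco and voc datasets, and what we need to do is add our own ImageNet loader there. First we create the main file for loading the ImageNet data. Both the coco and pascal_voc loaders inherit from the parent class imdb, so we can use the pascal_voc loader as a template and modify it into our ImageNet class.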
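
The registration step described above can be sketched as follows. This is a stand-in rather than the actual factory.py: the `imagenet` class here is only a placeholder, the registry is named `_sets` for this sketch, and the 'train' split name is an assumption.

```python
# Stand-in for the dataset registry: a dict that maps a dataset name
# to a zero-argument callable which builds the dataset object.

class imagenet(object):
    """Placeholder for the ImageNet imdb class built in this post."""
    def __init__(self, image_set):
        self.name = 'imagenet'
        self.image_set = image_set

_sets = {}

# Register the dataset under the name later passed as --imdb imagenet.
# The 'train' split name is an assumption made for this sketch.
_sets['imagenet'] = lambda: imagenet('train')

def get_imdb(name):
    """Look up a dataset by name and fail loudly on unknown names."""
    if name not in _sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    return _sets[name]()

imdb = get_imdb('imagenet')
print(imdb.name, imdb.image_set)
```

Storing a callable instead of an instance means a dataset is only constructed when it is actually requested.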

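    Creating the ImageNet class

    Because faster-rcnn uses the RPN in place of selective_search, we can skip the selective_search methods entirely. Taking the pascal_voc class as the template, the methods we need to keep are:

    __init__ // initialization
    image_path_at // returns the absolute image path for an index into the dataset list
    image_path_from_index // helper for the method above
    _load_image_set_index // loads the dataset list
    gt_roidb // loads the ground-truth data
    rpn_roidb // loads the region proposal data
    _load_rpn_roidb // builds the rpn_roidb data from gt_roidb and merges the two
    _load_pascal_annotation // loads an annotation file and organizes its bounding boxes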
    

    __init__:

    def __init__(self, image_set):
            imdb.__init__(self, 'imagenet')
            self._image_set = image_set
            self._data_path = os.path.join(cfg.DATA_DIR, "imagenet")
            # Classes and their wnids; change these to the classes you want to train
            self._class_wnids = {
                'cup': 'n03147509',
                'glasses': 'n04272054'
            }
    
            # The class tuple; update it whenever the classes above change
            self._classes = ('__background__', self._class_wnids['cup'], self._class_wnids['glasses'])
            self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
            # Directory holding the bounding box annotation files
            self._xml_path = os.path.join(self._data_path, "Annotations")
            self._image_ext = '.JPEG'
            # The xml file names serve as the dataset index;
            # each one corresponds to an image file name
            self._image_index = self._load_xml_filenames()
            self._salt = str(uuid.uuid4())
            self._comp_id = 'comp4'
    
            self.config = {'cleanup'     : True,
                           'use_salt'    : True,
                           'use_diff'    : False,
                           'matlab_eval' : False,
                           'rpn_file'    : None,
                           'min_size'    : 2}
    
            assert os.path.exists(self._data_path), \
                    'Path does not exist: {}'.format(self._data_path)
    

    image_path_at

    def image_path_at(self, i):
            # Look up the file name at index i in the xml file name list
            # and turn it into an absolute path
            return self.image_path_from_image_filename(self._image_index[i])
    

    image_path_from_image_filename (similar to image_path_from_index in pascal_voc)

    def image_path_from_image_filename(self, image_filename):
            image_path = os.path.join(self._data_path, 'Images',
                                      image_filename + self._image_ext)
            assert os.path.exists(image_path), \
                    'Path does not exist: {}'.format(image_path)
            return image_path
    

    _load_xml_filenames (similar to _load_image_set_index in pascal_voc)

    def _load_xml_filenames(self):
            # Collect the bounding box annotation file names under Annotations;
            # they serve as the index of the dataset
            xml_folder_path = os.path.join(self._data_path, "Annotations")
            assert os.path.exists(xml_folder_path), \
                'Path does not exist: {}'.format(xml_folder_path)
    
            # Accumulate across every directory visited; assigning inside the loop
            # would keep only the files of the last directory
            xml_filenames = []
            for dirpath, dirnames, filenames in os.walk(xml_folder_path):
                    xml_filenames.extend(xml_filename.split(".")[0]
                                         for xml_filename in filenames)
    
            return xml_filenames
    

    gt_roidb

    def gt_roidb(self):
            # Load the ground-truth data from the cache if it exists
            cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
            if os.path.exists(cache_file):
                with open(cache_file, 'rb') as fid:
                    roidb = cPickle.load(fid)
                print '{} gt roidb loaded from {}'.format(self.name, cache_file)
                return roidb
    
            # Otherwise read the ground-truth data from the xml files
            gt_roidb = [self._load_imagenet_annotation(xml_filename)
                        for xml_filename in self._image_index]
            with open(cache_file, 'wb') as fid:
                cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
            print 'wrote gt roidb to {}'.format(cache_file)
    
            return gt_roidb
    

    rpn_roidb

    def rpn_roidb(self):
            # Build rpn_roidb from gt_roidb, then merge the two
            gt_roidb = self.gt_roidb()
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
    
            return roidb
    

    _load_rpn_roidb

    def _load_rpn_roidb(self, gt_roidb):
            filename = self.config['rpn_file']
            print 'loading {}'.format(filename)
            assert os.path.exists(filename), \
                   'rpn data not found at: {}'.format(filename)
            with open(filename, 'rb') as f:
                box_list = cPickle.load(f)
            return self.create_roidb_from_box_list(box_list, gt_roidb)
    

    _load_imagenet_annotation (similar to _load_pascal_annotation in pascal_voc)

    def _load_imagenet_annotation(self, xml_filename):
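            # Read the bounding box data from the annotation xml file.
            # Annotations sit in one subdirectory per wnid (see the directory layout),
            # and the wnid is the part of the file name before the underscore
            filepath = os.path.join(self._data_path, 'Annotations',
                                    xml_filename.split('_')[0], xml_filename + '.xml')
            # ap is an annotation parser I wrote (import annotation_parser as ap);
            # its code is listed further down. It returns the wnid of the xml file,
            # the image file name, and the annotated objects it contains
            wnid, image_name, objects = ap.parse(filepath)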
            num_objs = len(objects)
    
            boxes = np.zeros((num_objs, 4), dtype=np.uint16)
            gt_classes = np.zeros((num_objs), dtype=np.int32)
            overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
            seg_areas = np.zeros((num_objs), dtype=np.float32)
    
            # Load object bounding boxes into a data frame.
            for ix, obj in enumerate(objects):
                box = obj["box"]
                x1 = box['xmin']
                y1 = box['ymin']
                x2 = box['xmax']
                y2 = box['ymax']
                # Skip this bounding box if its wnid is not one of
                # the classes we declared for training
                try:
                    cls = self._class_to_ind[obj["wnid"]]
                except KeyError:
                    print "wnid %s is not one of the declared classes" % obj["wnid"]
                    continue
                boxes[ix, :] = [x1, y1, x2, y2]
                gt_classes[ix] = cls
                overlaps[ix, cls] = 1.0
                seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
    
            overlaps = scipy.sparse.csr_matrix(overlaps)
    
            return {'boxes' : boxes,
                    'gt_classes': gt_classes,
                    'gt_overlaps' : overlaps,
                    'flipped' : False,
                    'seg_areas' : seg_areas}
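
To make the fields of that dictionary concrete, here is a dependency-free sketch of the same bookkeeping using plain lists instead of NumPy arrays and a sparse matrix. The function name `roidb_rows` and the sample box are made up for illustration, and unlike the method above this sketch drops unknown objects instead of leaving all-zero rows for them.

```python
def roidb_rows(objects, class_to_ind):
    """Build box, class, one-hot overlap and area rows for the known objects."""
    num_classes = len(class_to_ind)
    boxes, gt_classes, overlaps, seg_areas = [], [], [], []
    for obj in objects:
        cls = class_to_ind.get(obj["wnid"])
        if cls is None:
            continue  # not one of the classes being trained
        box = obj["box"]
        x1, y1, x2, y2 = box["xmin"], box["ymin"], box["xmax"], box["ymax"]
        boxes.append([x1, y1, x2, y2])
        gt_classes.append(cls)
        # A ground-truth box overlaps fully with its own class and no other
        overlaps.append([1.0 if i == cls else 0.0 for i in range(num_classes)])
        # Coordinates are treated as inclusive pixel indices, hence the +1
        seg_areas.append((x2 - x1 + 1) * (y2 - y1 + 1))
    return boxes, gt_classes, overlaps, seg_areas

class_to_ind = {'__background__': 0, 'n03147509': 1, 'n04272054': 2}
objects = [{"wnid": "n03147509",
            "box": {"xmin": 48, "ymin": 30, "xmax": 310, "ymax": 371}}]
print(roidb_rows(objects, class_to_ind))
```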
    

    The annotation_parser.py file

    import os
    import xml.dom.minidom
    
    def getText(node):
    	return node.firstChild.nodeValue
    
    def getWnid(node):
    	return getText(node.getElementsByTagName("name")[0])
    
    def getImageName(node):
    	return getText(node.getElementsByTagName("filename")[0])
    
    def getObjects(node):
    	objects = []
    	for obj in node.getElementsByTagName("object"):
    		objects.append({
    			"wnid": getText(obj.getElementsByTagName("name")[0]),
    			"box":{
    				"xmin": int(getText(obj.getElementsByTagName("xmin")[0])),
    				"ymin": int(getText(obj.getElementsByTagName("ymin")[0])),
    				"xmax": int(getText(obj.getElementsByTagName("xmax")[0])),
    				"ymax": int(getText(obj.getElementsByTagName("ymax")[0])),
    			}
    		})
    	return objects
    
    def parse(filepath):
    	dom = xml.dom.minidom.parse(filepath)
    	root = dom.documentElement
    	image_name = getImageName(root)
    	wnid = getWnid(root)
    	objects = getObjects(root)
    	
    	return wnid, image_name, objects
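
As a quick way to see what the parser returns, the sketch below runs the same minidom extraction on an in-memory string. The sample XML is invented for illustration (real annotation files may carry more fields), and `text_of` and `parse_string` are hypothetical helper names, not part of annotation_parser.py.

```python
import xml.dom.minidom

# Invented sample in the shape the parser expects
SAMPLE_XML = """<annotation>
  <folder>n03147509</folder>
  <filename>n03147509_1234</filename>
  <object>
    <name>n03147509</name>
    <bndbox><xmin>48</xmin><ymin>30</ymin><xmax>310</xmax><ymax>371</ymax></bndbox>
  </object>
</annotation>"""

def text_of(node, tag):
    """Return the text of the first <tag> element below node."""
    return node.getElementsByTagName(tag)[0].firstChild.nodeValue

def parse_string(xml_text):
    """Same extraction as parse(), but reading from a string instead of a file."""
    root = xml.dom.minidom.parseString(xml_text).documentElement
    objects = []
    for obj in root.getElementsByTagName("object"):
        objects.append({
            "wnid": text_of(obj, "name"),
            "box": {k: int(text_of(obj, k))
                    for k in ("xmin", "ymin", "xmax", "ymax")},
        })
    # As in parse(), the file-level wnid comes from the first <name> element
    return text_of(root, "name"), text_of(root, "filename"), objects

wnid, image_name, objects = parse_string(SAMPLE_XML)
print(wnid, image_name, objects)
```

Note that the file-level wnid is simply the name of the first annotated object, so a file whose first object belongs to another category reports that category's wnid.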
    

    The required directory layout is therefore:

    |---data
      |---imagenet
        |---Annotations
           |---n03147509
              |---n03147509_*.xml
              |---...
           |---n04272054
              |---n04272054_*.xml
              |---...
        |---Images
           |---n03147509_*.JPEG
           |---...
           |---n04272054_*.JPEG
           |---...
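
Since the index is built from the annotation files, every xml file needs a matching image. A small check along these lines can catch gaps before training; it is a sketch that assumes the layout above, and `find_missing_images` is a hypothetical helper name.

```python
import os

def find_missing_images(data_path, image_ext=".JPEG"):
    """Return annotation stems under Annotations/ that have no image in Images/."""
    missing = []
    annotations_dir = os.path.join(data_path, "Annotations")
    for dirpath, dirnames, filenames in os.walk(annotations_dir):
        for filename in sorted(filenames):
            stem, ext = os.path.splitext(filename)
            if ext != ".xml":
                continue  # ignore anything that is not an annotation file
            image_path = os.path.join(data_path, "Images", stem + image_ext)
            if not os.path.exists(image_path):
                missing.append(stem)
    return missing

if __name__ == "__main__":
    # "data/imagenet" follows the layout above; adjust it to your checkout
    for stem in find_missing_images(os.path.join("data", "imagenet")):
        print("no image for annotation: " + stem)
```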
    

    I also provide a draw method on GitHub that draws the bounding boxes onto the image files, which helps verify that an annotation is correct.

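    Training

    With that, our ImageNet class is complete and we can train on our data. One thing remains before starting: the values in the prototxt files that depend on the number of classes must be changed. I copied models/pascal_voc to models/imagenet and edited the copy. For example, to train ZF with train_faster_rcnn_alt_opt.py, edit every pt file under models/imagenet/ZF/faster_rcnn_alt_opt/ according to the following rules: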

    // num is the number of classes, including __background__
    input-data->num_classes = num
    cls_score->num_output   = num
    bbox_pred->num_output   = num*4
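
For the two classes trained in this post, the arithmetic works out as in the sketch below. `prototxt_values` is a hypothetical helper that only computes the numbers; it does not edit any pt file.

```python
def prototxt_values(classes):
    """Values to write into the pt files for a class tuple (background included)."""
    num = len(classes)
    return {
        "num_classes": num,               # input-data layer
        "score_num_output": num,          # classification layer: one score per class
        "bbox_pred_num_output": num * 4,  # four box coordinates per class
    }

classes = ('__background__', 'n03147509', 'n04272054')
print(prototxt_values(classes))
```

With background, cup and glasses this gives num = 3, so the bounding box regression layer needs 12 outputs.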
    

    I trained with train_faster_rcnn_alt_opt.py, so the added models/imagenet directory has to be made selectable as an option:

    # --pt_type is the option I added; without it the pascal_voc models are used
    # --weights is optional
    ./tools/train_faster_rcnn_alt_opt.py --gpu 0 \
    --net_name ZF \
    --weights data/imagenet_models/ZF.v2.caffemodel \
    --imdb imagenet \
    --cfg experiments/cfgs/faster_rcnn_alt_opt.yml \
    --pt_type imagenet
    

    Recognition

    Here we use the model we just trained to run recognition.

    # Like demo.py, but using the trained models; I created tools/classify.py to run recognition on its own
    prototxt = os.path.join(cfg.ROOT_DIR, 'models/imagenet', NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
    caffemodel = os.path.join(cfg.ROOT_DIR, 'output/faster_rcnn_alt_opt/imagenet/'+ NETS[args.demo_net][0] +'_faster_rcnn_final.caffemodel')
    

    Likewise, before running recognition, change the Classes in the recognition script to the classes you trained on.

    Then run:

    ./tools/classify.py --net zf
    

    This runs recognition on the image files under data/demo using the trained zf network.
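
The class list that has to be edited before running recognition could look like the sketch below. The tuple is named CLASSES here as in demo.py, and the `CLASS_WNIDS` table and `label_for` helper are assumptions for illustration; the important part is that index 0 stays background and the rest follow the training order.

```python
# Hypothetical class table: (readable label, wnid), in training order
CLASS_WNIDS = (
    ('cup', 'n03147509'),
    ('glasses', 'n04272054'),
)

# Index 0 is background; the remaining order must match the order
# of the class tuple used by the ImageNet class during training
CLASSES = ('__background__',) + tuple(label for label, _ in CLASS_WNIDS)

def label_for(class_index):
    """Map a class index from the network output to a readable label."""
    return CLASSES[class_index]

print(CLASSES)
```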

    Have fun

  • Original article: https://www.cnblogs.com/zhiyishou/p/5651321.html