zoukankan      html  css  js  c++  java
  • 使用faster rcnn 跑vot2015的数据集

    本周老师给的任务:

    一是将VOT15数据集(世华已传到服务器上)上每个序列的第1,11,21,31,41帧分别运行Faster R-CNN检测器并保存在图片上显示的检测结果;

    二是将这5帧的ground truth bounding box作为proposal得到其对应的检测器分类结果(比如网络要检测20类物体,那包括背景就是得到21类对应的检测分数值),并将每个序列的检测结果分别存成一个文本文档。

     注意,使用代码的时候,可能会有路径错误,还可能是,我贴上的代码,博客园的网站给在某些语句后加了 <br> ,调错的时候细看!!我在后台竟然看不到<br>,但是浏览的时候却有!!

    第一个问题已经解决,现在整理一下思路。

    先将py faster rcnn 装好之后,测试运行dome.py能成功展示之后,再进行接下来的工作。

    我的想法是,

    (1)将vot2015数据集上的所有数据的分类统计出来(就是把vot2015下的子文件夹的名称统计出来,方便之后操作),这里直接用了( http://www.cnblogs.com/flyhigh1860/p/3896111.html )的源码进行修改

    #!/usr/bin/python
    # -*- coding:utf8 -*-
    
    import os
    allFileNum = 0
    
    def printPath(level, path):
        global allFileNum
        ''''' 
        打印一个目录下的所有文件夹和文件 
        '''
        # 所有文件夹,第一个字段是次目录的级别
        dirList = []
        # 所有文件
        fileList = []
        # 返回一个列表,其中包含在目录条目的名称(google翻译)
        files = os.listdir(path)
        # 先添加目录级别
        dirList.append(str(level))
        for f in files:
            if (os.path.isdir(path + '/' + f)):
                # 排除隐藏文件夹。因为隐藏文件夹过多
                if (f[0] == '.'):
                    pass
                else:
                    # 添加非隐藏文件夹
                    dirList.append(f)
            if (os.path.isfile(path + '/' + f)):
                # 添加文件
                fileList.append(f)
                # 当一个标志使用,文件夹列表第一个级别不打印
        i_dl = 0
      #得到的文件夹名保存在 save_file.txt 中,使用python的追加操作 ‘a’ save_file = open('/home/user/Downloads/save_file.txt','a') for dl in dirList: if (i_dl == 0): i_dl = i_dl + 1 else: # 打印至控制台,不是第一个的目录 print '-' * (int(dirList[0])), dl
           #将文件名写入save_file.txt中 save_file.write(dl) save_file.write(' ') # 打印目录下的所有文件夹和文件,目录级别+1 #printPath((int(dirList[0]) + 1), path + '/' + dl) for fl in fileList: # 打印文件 print '-' * (int(dirList[0])), fl # 随便计算一下有多少个文件 allFileNum = allFileNum + 1 if __name__ == '__main__': printPath(1, '/home/user/Downloads/vot2015') print '总文件数 =', allFileNum

     这里再给出save_file.txt 文件内容

    soldier
    butterfly
    hand
    car2
    sheep
    birds1
    motocross1
    marching
    book
    road
    graduate
    fish3
    fernando
    bag
    wiper
    gymnastics2
    leaves
    ball1
    birds2
    crossing
    soccer1
    godfather
    nature
    racing
    traffic
    pedestrian2
    handball2
    ball2
    gymnastics1
    singer2
    singer1
    dinosaur
    gymnastics3
    bolt1
    gymnastics4
    pedestrian1
    helicopter
    singer3
    matrix
    octopus
    iceskater1
    fish4
    sphere
    car1
    motocross2
    girl
    fish1
    bolt2
    basketball
    blanket
    bmx
    shaking
    tiger
    handball1
    rabbit
    fish2
    tunnel
    glove
    iceskater2
    soccer2
    

     (2)从save_file.txt 中将分来读取出来,保存再一个list中,之后将这段代码加到 demo.py 中使用(参考了  http://www.cnblogs.com/xuxn/archive/2011/07/27/read-a-file-with-python.html     和    http://www.cnblogs.com/mxh1099/p/5680001.html)

    l = []
    
    file = open('/home/user/Downloads/save_file.txt')
    
    while 1:
        line = file.readline()
        if line != '
    ':
            print line.replace("
    ", "")
         #在list中 加入去掉换行符的文件名 l.append(line.replace(" ","")) if not line: break print l

     (3)需要将文件名和要遍历的每个文件夹下的文件名配合,同样,这段代码之后会用在demo.py 中

    lfile = []
    
    file = open('/home/user/Downloads/save_file.txt')
    
    while 1:
        line = file.readline()
        if line != '
    ':
            lfile.append(line.replace("
    ", ""))
        if not line:
            break
    im_names =['00000023.jpg','00000011.jpg','00000001.jpg']
        # im_names = ['00000001.jpg', '000000011.jpg', '00000021.jpg',
        #             '00000031.jpg', '00000041.jpg']
    
    for litme in lfile :
        for im_name in im_names:
            im_path = str(litme) + '/' + str(im_name)
            print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
            #print 'Demo for data/demo/{}'.format(im_name)
            print im_path
    

     (4)可以对文件遍历之后,需要将生成的图片结果保存下来,参考了《演示如何实现Matplotlib绘图并保存图像但不显示图形的方法》(http://blog.csdn.net/rumswell/article/details/7342479) 和Python创建目录文件夹 (http://www.cnblogs.com/monsteryang/p/6574550.html)

    最后附上我修改之后的demo.py

    #!/usr/bin/env python
    
    # --------------------------------------------------------
    # Faster R-CNN
    # Copyright (c) 2015 Microsoft
    # Licensed under The MIT License [see LICENSE for details]
    # Written by Ross Girshick
    # --------------------------------------------------------
    
    """
    Demo script showing detections in sample images.
    
    See README.md for installation instructions before running.
    """
    
    import _init_paths
    from fast_rcnn.config import cfg
    from fast_rcnn.test import im_detect
    from fast_rcnn.nms_wrapper import nms
    from utils.timer import Timer
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np
    import scipy.io as sio
    import caffe, os, sys, cv2
    import argparse
    
    #add
    matplotlib.use('Agg')
    
    CLASSES = ('__background__',
               'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair',
               'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant',
               'sheep', 'sofa', 'train', 'tvmonitor')
    
    NETS = {'vgg16': ('VGG16',
                      'VGG16_faster_rcnn_final.caffemodel'),
            'zf': ('ZF',
                      'ZF_faster_rcnn_final.caffemodel')}
    
    #add
    def mkdir(path):
        import os
    
        path = path.strip()
        path = path.rstrip("\")
    
        isExists = os.path.exists(path)
        if not isExists:
            os.makedirs(path)
            print path + 'ok'
            return True
        else:
    
            print path + 'failed!'
            return False
    
    def vis_detections(image_name, im, class_name, dets, thresh=0.5):
        """Draw detected bounding boxes."""
        inds = np.where(dets[:, -1] >= thresh)[0]
        if len(inds) == 0:
            return
    
        im = im[:, :, (2, 1, 0)]
        fig, ax = plt.subplots(figsize=(12, 12))
        ax.imshow(im, aspect='equal')
        for i in inds:
            bbox = dets[i, :4]
            score = dets[i, -1]
    
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=3.5)
                )
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(class_name, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')
    
        ax.set_title(('{} detections with '
                      'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                      thresh),
                      fontsize=14)
        plt.axis('off')
        plt.tight_layout()
        plt.draw()
        #add
        ll = []
        ll = str(image_name).split('/')
        print ll[0]
    
        mkdir('/home/user/tmp/' + str(ll[0]))
        plt.savefig('/home/user/tmp/' + str(image_name))
    
    def demo(net, image_name):
        """Detect object classes in an image using pre-computed object proposals."""
    
        # Load the demo image
        im_file = os.path.join(cfg.DATA_DIR, 'demo','vot2015', image_name)
        print("%s", im_file)
        im = cv2.imread(im_file)
    
        # Detect all object classes and regress object bounds
        timer = Timer()
        timer.tic()
        #add try except
        try:
            scores, boxes = im_detect(net, im)
            timer.toc()
            print ('Detection took {:.3f}s for '
                   '{:d} object proposals').format(timer.total_time, boxes.shape[0])
            # Visualize detections for each class
            CONF_THRESH = 0.8
            NMS_THRESH = 0.3
            for cls_ind, cls in enumerate(CLASSES[1:]):
                cls_ind += 1 # because we skipped background
                cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
                cls_scores = scores[:, cls_ind]
                dets = np.hstack((cls_boxes,
                                  cls_scores[:, np.newaxis])).astype(np.float32)
                keep = nms(dets, NMS_THRESH)
                dets = dets[keep, :]
                vis_detections(image_name,im, cls, dets, thresh=CONF_THRESH)
        except Exception:
            print 'Error'
    
    def parse_args():
        """Parse input arguments."""
        parser = argparse.ArgumentParser(description='Faster R-CNN demo')
        parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                            default=0, type=int)
        parser.add_argument('--cpu', dest='cpu_mode',
                            help='Use CPU mode (overrides --gpu)',
                            action='store_true')
        parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
                            choices=NETS.keys(), default='vgg16')
    
        args = parser.parse_args()
    
        return args
    
    if __name__ == '__main__':
        cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    
        args = parse_args()
    
        prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
                                'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
        caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',
                                  NETS[args.demo_net][1])
    
        if not os.path.isfile(caffemodel):
            raise IOError(('{:s} not found.
    Did you run ./data/script/'
                           'fetch_faster_rcnn_models.sh?').format(caffemodel))
    
        if args.cpu_mode:
            caffe.set_mode_cpu()
        else:
            caffe.set_mode_gpu()
            caffe.set_device(args.gpu_id)
            cfg.GPU_ID = args.gpu_id
        net = caffe.Net(prototxt, caffemodel, caffe.TEST)
    
        print '
    
    Loaded network {:s}'.format(caffemodel)
    
        # Warmup on a dummy image
        im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
        for i in xrange(2):
            _, _= im_detect(net, im)
    
        # im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
        #             '001763.jpg', '004545.jpg','00000023.jpg','00000011.jpg','00000001.jpg']
    
        # edit
        lfile = []
    
        file = open('/home/user/Downloads/save_file.txt')
    
        while 1:
            line = file.readline()
            if line != '
    ':
                lfile.append(line.replace("
    ", ""))
            if not line:
                break
    
        print lfile
    
        im_names = ['00000001.jpg', '00000011.jpg', '00000021.jpg',
                    '00000031.jpg', '00000041.jpg']
    
        for litme in lfile :
            for im_name in im_names:
                im_path = str(litme) + '/' + str(im_name)
                print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
                print 'Demo for data/demo/{}'.format(im_name)
                try:
                    demo(net, im_path)
                except Exception:
                    print 'ERROR'
        #plt.show()
    

    第二个问题先看着,没想法

     在图片上显示每个IOU大于0.5的proposal对应的最高检测值的类别、分数和回归后的框,在文本文档里则保存每个proposal对应的21个类别的检测分数和回归后的边界框坐标。

    对于每个类别,总会生成300个proposals,

    所以,在每个proposal,都会有4个坐标

    对于每个proposal,都会有一个类别值。

    因为要生成每个proposal对应的21个类别的分数,就需要将分数先保存起来,再输出

    还要记录回归后的边间框。

    对于图片,显示每个IOU大于0.5的proposal对应的最高检测值的类别、分数和回归后的框。

    也是先要将最高检测分数对应的类别和回归框记录下来。

  • 相关阅读:
    用SSH指令批量修改文件夹 文件权限和拥有者
    magento转移服务器和magento建立多站点总结
    ssh 命令
    magento缓存系列详解:clean cache
    如何配置magento免运费商品方法
    Magento后台订单显示产品图片的修改方法
    如何在magento后台增加一个自定义订单状态
    Magento路径函数getBaseUrl使用方法
    图解HTTPS
    php 数组 添加元素、删除元素
  • 原文地址:https://www.cnblogs.com/ya-cpp/p/7780455.html
Copyright © 2011-2022 走看看