zoukankan      html  css  js  c++  java
  • Faster R-CNN(TensorFlow 版)测试时生成 result.txt 检测结果文件及检测结果图

    需要准备三个路径(此外,脚本还会读取 ImageSets/Main/test.txt 中的测试图片列表以及训练好的模型文件):

    1、测试所用图片的存放路径(代码中为 VOC2007/JPEGImages);

    2、result.txt 的输出路径(代码中为 VOC2007/ImageSets/Main/result.txt);

    3、生成的可视化结果图像的存放路径(代码中为 data/results)。

      1 #!/usr/bin/env python
      2  
      3 
      4 from __future__ import absolute_import
      5 from __future__ import division
      6 from __future__ import print_function
      7  
      8 import _init_paths
      9 from model.config import cfg
     10 from model.test import im_detect
     11 from model.nms_wrapper import nms
     12  
     13 from utils.timer import Timer
     14 import tensorflow as tf
     15 import matplotlib.pyplot as plt
     16 from PIL import Image
     17 import numpy as np
     18 import os, cv2
     19 import argparse
     20  
     21  
     22 from nets.vgg16 import vgg16
     23 from nets.resnet_v1 import resnetv1
     24  
     25 CLASSES = ('__background__', 'dan','duo')
     26  
     27 NETS = {'vgg16': ('vgg16_faster_rcnn_iter_2000.ckpt',),'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
     28 DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
     29 
     30 
     31 def vis_detections(image_name, class_name, dets, thresh=0.7):
     32     """Draw detected bounding boxes."""
     33     inds = np.where(dets[:, -1] >= thresh)[0]
     34     if len(inds) == 0:
     35         return
     36     for i in inds:
     37         bbox = dets[i, :4]
     38         score = dets[i, -1]
     39         if(class_name == '__background__'):
     40             fw = open('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/result.txt','a')   
     41             fw.write(str(image_name)+' '+class_name+' '+str(int(bbox[0]))+' '+str(int(bbox[1]))+' '+str(int(bbox[2]))+' '+str(int(bbox[3]))+'
    ')
     42             fw.close()
     43         elif(class_name == 'dan'):
     44             fw = open('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/result.txt','a')
     45             fw.write(str(image_name)+' '+class_name+' '+str(int(bbox[0]))+' '+str(int(bbox[1]))+' '+str(int(bbox[2]))+' '+str(int(bbox[3]))+'
    ')
     46             fw.close()
     47         elif(class_name == 'duo'):
     48             fw = open('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/result.txt','a')
     49             fw.write(str(image_name)+' '+class_name+' '+str(int(bbox[0]))+' '+str(int(bbox[1]))+' '+str(int(bbox[2]))+' '+str(int(bbox[3]))+'
    ')
     50             fw.close()
     51  
     52  
     53 def demo(image_name, sess, net):
     54     im_file = os.path.join('/','home','bioinfo','Documents','pathonwork','lzh','tfasterrcnn', 'data', 'VOCdevkit2007', 'VOC2007', 'JPEGImages',image_name)
     55     #im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
     56     im = cv2.imread(im_file)
     57     # Detect all object classes and regress object bounds
     58     timer = Timer()
     59     timer.tic()
     60     scores, boxes = im_detect(sess, net, im)
     61     timer.toc()
     62     print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))
     63     # Visualize detections for each class
     64     CONF_THRESH = 0.7
     65     thresh=0.7
     66     NMS_THRESH = 0.3
     67     im = im[:, :, (2, 1, 0)]
     68     fig, ax = plt.subplots(figsize=(12, 12))
     69     ax.imshow(im, aspect='equal', alpha=0.5)
     70     for cls_ind, cls in enumerate(CLASSES[1:]):
     71         cls_ind += 1 # because we skipped background
     72         cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
     73         cls_scores = scores[:, cls_ind]
     74         dets = np.hstack((cls_boxes,
     75                           cls_scores[:, np.newaxis])).astype(np.float32)
     76         keep = nms(dets, NMS_THRESH)
     77         dets = dets[keep, :]
     78         vis_detections(image_name, cls, dets, thresh=CONF_THRESH)
     79         inds = np.where(dets[:, -1] >= thresh)[0]
     80         if len(inds) == 0:
     81             continue
     82         for i in inds:
     83             bbox = dets[i, :4]
     84             score = dets[i, -1]
     85             ax.add_patch(
     86                 plt.Rectangle((bbox[0], bbox[1]),
     87                               bbox[2] - bbox[0],
     88                               bbox[3] - bbox[1], fill=False,
     89                               edgecolor='red', linewidth=1.5)
     90                 )
     91             ax.text(bbox[0], bbox[1] - 2,
     92                 '{:s} {:.3f}'.format(cls, score),
     93                 bbox=dict(facecolor='blue', alpha=0.5),
     94                 fontsize=14, color='white')
     95     
     96     plt.axis('off')
     97     plt.tight_layout()
     98     plt.draw()
     99     image_name=image_name.replace('jpg','jpg')
    100     plt.savefig('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/results/'+image_name)
    101     print("save image to /home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/results/{}".format(image_name))
    102  
    103 def parse_args():
    104     """Parse input arguments."""
    105     parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    106     parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
    107                         choices=NETS.keys(), default='vgg16')
    108     parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
    109                         choices=DATASETS.keys(), default='pascal_voc')
    110     args = parser.parse_args()
    111     
    112     return args
    113  
    114 if __name__ == '__main__':
    115     cfg.TEST.HAS_RPN = True # Use RPN for proposals
    116     args = parse_args()
    117     # model path
    118     demonet = args.demo_net
    119     dataset = args.dataset
    120     tfmodel = ('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/output/vgg16/voc_2007_trainval/default/vgg16_faster_rcnn_iter_2000.ckpt')
    121     
    122     if not os.path.isfile(tfmodel + '.meta'):
    123         raise IOError(('{:s} not found.
    Did you download the proper networks from '
    124                       'our server and place them properly?').format(tfmodel + '.meta'))
    125     # set config
    126     tfconfig = tf.ConfigProto(allow_soft_placement=True)
    127     tfconfig.gpu_options.allow_growth=True
    128     # init session
    129     sess = tf.Session(config=tfconfig)
    130     # load network
    131     if demonet == 'vgg16':
    132         net = vgg16()
    133     elif demonet == 'res101':
    134         net = resnetv1(num_layers=101)
    135     else:
    136         raise NotImplementedError
    137     net.create_architecture("TEST", 3,
    138                           tag='default', anchor_scales=[8, 16, 32])
    139     saver = tf.train.Saver()
    140     saver.restore(sess, tfmodel)
    141     print('Loaded network {:s}'.format(tfmodel))
    142     
    143     fi=open('/home/bioinfo/Documents/pathonwork/lzh/tfasterrcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt','r')
    144     txt=fi.readlines()
    145     im_names = []
    146     for line in txt:
    147         line=line.strip('
    ')
    148         line=line.replace('
    ','')
    149         line=(line+'.jpg')
    150         im_names.append(line)
    151     print(im_names)
    152     fi.close()
    153     for im_name in im_names:
    154         print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    155         print('Demo for data/demo/{}'.format(im_name))
    156         demo(im_name, sess, net)
    157     #plt.show()#

    生成结果

    参考文章:

    1.https://blog.csdn.net/zxs0222/article/details/89605300

    2.https://blog.csdn.net/gusui7202/article/details/83240212

    作者:舟华520

    出处:https://www.cnblogs.com/xfzh193/

    本文以学习、分享、研究交流为主,欢迎转载,请注明作者及出处!

  • 相关阅读:
    Java内置包装类
    for循环思路题
    常用函数
    函数
    冒泡排序
    数组的运用
    for循环中有意思的练习题。
    for循环
    运算中容易出现的错误
    分支的运用
  • 原文地址:https://www.cnblogs.com/xfzh193/p/11934815.html
Copyright © 2011-2022 走看看