zoukankan      html  css  js  c++  java
  • 学习Faster R-CNN代码demo(一)

    注释Yang Jianwei 的Faster R-CNN代码(PyTorch)

    jwyang’s github: https://github.com/jwyang/faster-rcnn.pytorch

    文件demo.py 

    这个文件是自己下载好训练好的模型后可执行

    下面是对代码的详细注释(直接在代码上注释):

    1.有关导入的库

     1 # --------------------------------------------------------
     2 # Tensorflow Faster R-CNN
     3 # Licensed under The MIT License [see LICENSE for details]
     4 # Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick
     5 # --------------------------------------------------------
     6 
     7 #Python提供了__future__模块,把下一个新版本的特性导入到当前版本
     8 from __future__ import absolute_import#加入绝对引入这个新特性  引入系统的标准
     9 
    10 #导入python未来支持的语言特征division(精确除法),当我们没有在程序中导入该特征时,"/"操作符执行的是截断除法(Truncating Division),
    11 #当我们导入精确除法之后,"/"执行的是精确除法
    12 from __future__ import division
    13 
    14 #即使在python2.X,使用print就得像python3.X那样加括号使用
    15 from __future__ import print_function
    16 
    17 #_init_paths是指lib/model/_init_paths.py ?
    18 import _init_paths 
    19 import os #通过os模块调用系统命令
    20 import sys #sys 模块包括了一组非常实用的服务,内含很多函数方法和变量
    21 #numpy用来处理图片数据(多维数组), 尤其是numpy的broadcasting特性, 使得不同维度的数组可以一起操作(加,减,乘, 除, 等).
    22 import numpy as np
    23 import argparse #为py文件封装好可以选择的参数
    24 import pprint #提供了打印出任何python数据结构类和方法。
    25 import pdb #使用 Pdb调试 Python程序
    26 import time
    27 import cv2
    28 import torch
    29 #介绍autograde  https://www.jianshu.com/p/cbce2dd60120
    30 from torch.autograd import Variable#自动微分 vairable是tensor的一个外包装
    31 import torch.nn as nn
    32 import torch.optim as optim
    33 
    34 #为了方便加载以上五种数据库的数据,pytorch团队帮我们写了一个torchvision包。
    35 #使用torchvision就可以轻松实现数据的加载和预处理。
    36 import torchvision.transforms as transforms# transforms用于数据预处理
    37 import torchvision.datasets as dset
    38 
    39 #scipy.misc 下的图像处理
    40 #imread():返回的是 numpy.ndarray 也即 numpy 下的多维数组对象;
    41 from scipy.misc import imread
    42 
    43 from roi_data_layer.roidb import combined_roidb
    44 from roi_data_layer.roibatchLoader import roibatchLoader
    45 #demo.py运行过程中的配置基本上都在config.py了. 后续的代码流程中会用到这些配置值. 
    46 from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
    47 from model.rpn.bbox_transform import clip_boxes
    48 from model.nms.nms_wrapper import nms
    49 from model.rpn.bbox_transform import bbox_transform_inv
    50 from model.utils.net_utils import save_net, load_net, vis_detections
    51 from model.utils.blob import im_list_to_blob
    52 from model.faster_rcnn.vgg16 import vgg16
    53 from model.faster_rcnn.resnet import resnet
    54 import pdb
    55 
    56 try:
    57     xrange          # Python 2
    58 except NameError:
    59     xrange = range  # Python 3

    2.解析参数 parse_args()

     1 def parse_args():
     2   """
     3   Parse input arguments
     4   """
     5   parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
     6   parser.add_argument('--dataset', dest='dataset',#指代你跑得数据集名称,例如pascal-voc
     7                       help='training dataset',
     8                       default='pascal_voc', type=str)
     9   parser.add_argument('--cfg', dest='cfg_file',#配置文件
    10                       help='optional config file',
    11                       default='cfgs/vgg16.yml', type=str)
    12   parser.add_argument('--net', dest='net',#backbone网络类型
    13                       help='vgg16, res50, res101, res152',
    14                       default='res101', type=str)
    15   parser.add_argument('--set', dest='set_cfgs',#设置
    16                       help='set config keys', default=None,
    17                       nargs=argparse.REMAINDER)
    18   parser.add_argument('--load_dir', dest='load_dir',#模型目录
    19                       help='directory to load models',
    20                       default="/srv/share/jyang375/models")
    21   parser.add_argument('--image_dir', dest='image_dir',#图片目录
    22                       help='directory to load images for demo',
    23                       default="images")
    24   parser.add_argument('--cuda', dest='cuda',#是否用GPU
    25                       help='whether use CUDA',
    26                       action='store_true')
    27   parser.add_argument('--mGPUs', dest='mGPUs',#是不是多GPU
    28                       help='whether use multiple GPUs',
    29                       action='store_true')
    30     #class-agnostic 方式只回归2类bounding box,即前景和背景,
    31     #结合每个box在classification 网络中对应着所有类别的得分,以及检测阈值条件,就可以得到图片中所有类别的检测结果
    32   parser.add_argument('--cag', dest='class_agnostic',#是否class_agnostic回归
    33                       help='whether perform class_agnostic bbox regression',
    34                       action='store_true')
    35   parser.add_argument('--parallel_type', dest='parallel_type',#模型的哪一部分并行
    36                       help='which part of model to parallel, 0: all, 1: model before roi pooling',
    37                       default=0, type=int)
    38   parser.add_argument('--checksession', dest='checksession',
    39                       help='checksession to load model',
    40                       default=1, type=int)
    41   parser.add_argument('--checkepoch', dest='checkepoch',
    42                       help='checkepoch to load network',
    43                       default=1, type=int)
    44   #--checkpoint  a way to save the current state of your experiment so that you can pick up from where you left off.
    45   parser.add_argument('--checkpoint', dest='checkpoint',#跟保存模型有关
    46                       help='checkpoint to load network',
    47                       default=10021, type=int)
    48   parser.add_argument('--bs', dest='batch_size',#批大小
    49                       help='batch_size',
    50                       default=1, type=int)
    51   parser.add_argument('--vis', dest='vis',
    52                       help='visualization mode',#可视化模型
    53                       action='store_true')
    54   parser.add_argument('--webcam_num', dest='webcam_num',#好像就是网络哦摄像机
    55                       help='webcam ID number',
    56                       default=-1, type=int)
    57 
    58   #parse_args()是将之前add_argument()定义的参数进行赋值,并返回相关的namespace。
    59   args = parser.parse_args()
    60   return args
    61 
    62 lr = cfg.TRAIN.LEARNING_RATE#学习率
    63 momentum = cfg.TRAIN.MOMENTUM#动量
    64 weight_decay = cfg.TRAIN.WEIGHT_DECAY#权重衰减

    函数 _get_image_blob(im)

     1 def _get_image_blob(im):
     2 #这个函数其实就是读取图片,然后做尺寸变换,然后存储成矩阵的形式
     3   """Converts an image into a network input.
     4   Arguments:
     5     im (ndarray): a color image in BGR order
     6   Returns:
     7     blob (ndarray): a data blob holding an image pyramid 
     8     im_scale_factors (list): list of image scales (relative to im) used
     9       in the image pyramid
    10   """
    11   #Numpy中 astype:转换数组的数据类型。
    12   im_orig = im.astype(np.float32, copy=True)
    13   #而pixel mean的话,其实是把训练集里面所有图片的所有R通道像素,求了均值,G,B通道类似
    14   im_orig -= cfg.PIXEL_MEANS
    15 
    16   im_shape = im_orig.shape
    17   #所有元素中的min or max
    18   im_size_min = np.min(im_shape[0:2])#后面有可能有其他维度,这里留两维
    19   im_size_max = np.max(im_shape[0:2])
    20 
    21   processed_ims = []
    22   im_scale_factors = []
    23 
    24   for target_size in cfg.TEST.SCALES:#遍历cfg.TEST.SCALES这个元组或列表中的值
    25     im_scale = float(target_size) / float(im_size_min)#测试的尺度除以图像最小长度(宽高的最小值)
    26     # Prevent the biggest axis from being more than MAX_SIZE
    27     #防止最大值超过MAX_SIZE,round函数四舍五入
    28     if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
    29       im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
    30     #调整im_orig大小
    31     im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
    32             interpolation=cv2.INTER_LINEAR)
    33     #保存尺度值
    34     im_scale_factors.append(im_scale)
    35     #保存调整后的图像
    36     processed_ims.append(im)
    37 
    38   # Create a blob to hold the input images
    39   #创建一个blob来保存输入图像
    40   #这个函数出自这里 from model.utils.blob import im_list_to_blob
    41   blob = im_list_to_blob(processed_ims)#processed_ims是调整后的图像值
    42 
    43   return blob, np.array(im_scale_factors)

    4.主函数 if name == ‘main’:

      1 if __name__ == '__main__':
      2 
      3   args = parse_args()#这就是上面定义的那个函数
      4 
      5   print('Called with args:')
      6   print(args)
      7 
      8   if args.cfg_file is not None: #配置文件
      9     #model.utils.config 该文件中函数 """Load a config file and merge it into the default options."""
     10     cfg_from_file(args.cfg_file) #
     11   if args.set_cfgs is not None: #设置配置
     12     #model.utils.config文件中"""Set config keys via list (e.g., from command line)."""
     13     cfg_from_list(args.set_cfgs) #
     14     
     15   #Use GPU implementation of non-maximum suppression
     16   #解析参数是不是用GPU
     17   cfg.USE_GPU_NMS = args.cuda 
     18 
     19   print('Using config:')
     20   pprint.pprint(cfg)
     21   
     22   #设置随机数种子
     23   #每次运行代码时设置相同的seed,则每次生成的随机数也相同,
     24   #如果不设置seed,则每次生成的随机数都会不一样
     25   np.random.seed(cfg.RNG_SEED)
     26 
     27   # train set
     28   # -- Note: Use validation set and disable the flipped to enable faster loading.
     29   
     30   #load_dir 模型目录   args.net 网络   args.dataset 数据集
     31   input_dir = args.load_dir + "/" + args.net + "/" + args.dataset
     32   if not os.path.exists(input_dir):
     33     #当程序出现错误,python会自动引发异常,也可以通过raise显示地引发异常。一旦执行了raise语句,raise后面的语句将不能执行。
     34     raise Exception('There is no input directory for loading network from ' + input_dir)
     35   
     36   #这里的三个check参数,是定义了训好的检测模型名称,例如训好的名称为faster_rcnn_1_20_10021,
     37   #代表了checksession = 1,checkepoch = 20, checkpoint = 10021,这样才可以读到模型“faster_rcnn_1_20_10021”
     38   load_name = os.path.join(input_dir,
     39     'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
     40 
     41   #PASCAL类别 1类背景 + 20类Object
     42   #array和asarray都可以将结构数据转化为ndarray,但是主要区别就是当数据源是ndarray时,
     43   #array仍然会copy出一个副本,占用新的内存,但asarray不会。
     44   pascal_classes = np.asarray(['__background__',
     45                        'aeroplane', 'bicycle', 'bird', 'boat',
     46                        'bottle', 'bus', 'car', 'cat', 'chair',
     47                        'cow', 'diningtable', 'dog', 'horse',
     48                        'motorbike', 'person', 'pottedplant',
     49                        'sheep', 'sofa', 'train', 'tvmonitor'])
     50 
     51   # initilize the network here.
     52   #class-agnostic 方式只回归2类bounding box,即前景和背景
     53   if args.net == 'vgg16':
     54     fasterRCNN = vgg16(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic)
     55   elif args.net == 'res101':
     56     fasterRCNN = resnet(pascal_classes, 101, pretrained=False, class_agnostic=args.class_agnostic)
     57   elif args.net == 'res50':
     58     fasterRCNN = resnet(pascal_classes, 50, pretrained=False, class_agnostic=args.class_agnostic)
     59   elif args.net == 'res152':
     60     fasterRCNN = resnet(pascal_classes, 152, pretrained=False, class_agnostic=args.class_agnostic)
     61   else:
     62     print("network is not defined")
     63     #到了pdb.set_trace()那就会定下来,就可以看到调试的提示符(Pdb)了
     64     pdb.set_trace()
     65 
     66   fasterRCNN.create_architecture()#model.faster_rcnn.faster_rcnn.py 初始化模型 初始化权重
     67 
     68   print("load checkpoint %s" % (load_name))#模型路径
     69   if args.cuda > 0:#GPU
     70     checkpoint = torch.load(load_name)
     71   else:#CPU?
     72     ################################################################
     73     #在cpu上加载预先训练好的GPU模型,强制所有GPU张量在CPU中的方式:
     74     checkpoint = torch.load(load_name, map_location=(lambda storage, loc: storage))
     75   
     76   #the_model = TheModelClass(*args, **kwargs)
     77   #the_model.load_state_dict(torch.load(PATH))###恢复恢复
     78   fasterRCNN.load_state_dict(checkpoint['model'])#恢复模型
     79   if 'pooling_mode' in checkpoint.keys():
     80     cfg.POOLING_MODE = checkpoint['pooling_mode']#pooling方式
     81 
     82 
     83   print('load model successfully!')
     84 
     85   # pdb.set_trace()
     86 
     87   print("load checkpoint %s" % (load_name))
     88 
     89   # initilize the tensor holder here.
     90   #新建一些 一维Tensor
     91   im_data = torch.FloatTensor(1)
     92   im_info = torch.FloatTensor(1)
     93   num_boxes = torch.LongTensor(1)
     94   gt_boxes = torch.FloatTensor(1)
     95 
     96   # ship to cuda
     97   if args.cuda > 0:#如果用GPU,张量放到GPU上
     98     im_data = im_data.cuda()
     99     im_info = im_info.cuda()
    100     num_boxes = num_boxes.cuda()
    101     gt_boxes = gt_boxes.cuda()
    102 
    103   # make variable
    104   #ariable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,
    105   #那么所有依赖它的节点volatile属性都为True。
    106   #volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。
    107   im_data = Variable(im_data, volatile=True)
    108   im_info = Variable(im_info, volatile=True)
    109   num_boxes = Variable(num_boxes, volatile=True)
    110   gt_boxes = Variable(gt_boxes, volatile=True)
    111 
    112   if args.cuda > 0:
    113     cfg.CUDA = True
    114 
    115   if args.cuda > 0:
    116     fasterRCNN.cuda()
    117 
    118   #model.eval(),让model变成测试模式,
    119   #对dropout和batch normalization的操作在训练和测试的时候是不一样的
    120   #pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值
    121   fasterRCNN.eval()
    122 
    123   #通过time()函数可以获取当前的时间
    124   start = time.time()
    125   max_per_image = 100
    126   thresh = 0.05
    127   vis = True
    128 
    129   webcam_num = args.webcam_num
    130   # Set up webcam or get image directories
    131   if webcam_num >= 0 :#应该就是判断要不要自己用电脑录视频
    132     #cap = cv2.VideoCapture(0) 打开笔记本的内置摄像头。
    133     #cap = cv2.VideoCapture('D:output.avi') 打开视频文件
    134     cap = cv2.VideoCapture(webcam_num)
    135     num_images = 0
    136   else:#如果不用电脑录视频,那么就读取image路径下的图片
    137     #os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
    138     #这个列表以字母顺序
    139     imglist = os.listdir(args.image_dir)
    140     num_images = len(imglist)#有多少张图片
    141 
    142   print('Loaded Photo: {} images.'.format(num_images))
    143 
    144 
    145   while (num_images >= 0):
    146       total_tic = time.time()#当前时间
    147       if webcam_num == -1:#如果不用摄像头
    148         num_images -= 1
    149 
    150       # Get image from the webcam
    151       #从电脑摄像头读取图片
    152       if webcam_num >= 0:
    153         if not cap.isOpened():#摄像头开启失败
    154           raise RuntimeError("Webcam could not open. Please check connection.")
    155         
    156         #ret 为True 或者False,代表有没有读取到图片
    157         #frame表示截取到一帧的图片
    158         ret, frame = cap.read()
    159         
    160         #摄像头截取到一帧的图片 存储为numpy数组
    161         im_in = np.array(frame)
    162       # Load the demo image
    163       else:
    164         #图片路径
    165         im_file = os.path.join(args.image_dir, imglist[num_images])
    166         # im = cv2.imread(im_file)
    167         #读取的的图片 存储为numpy数组
    168         im_in = np.array(imread(im_file))
    169       if len(im_in.shape) == 2:
    170         #np.newaxis的作用就是在这一位置增加一个一维,
    171         #这一位置指的是np.newaxis所在的位置,比较抽象,需要配合例子理解。
    172         #####################example
    173         #x1 = np.array([1, 2, 3, 4, 5])
    174         # the shape of x1 is (5,)
    175         #x1_new = x1[:, np.newaxis]
    176         # now, the shape of x1_new is (5, 1)
    177         # array([[1],
    178         #        [2],
    179         #        [3],
    180         #        [4],
    181         #        [5]])
    182         #x1_new = x1[np.newaxis,:]
    183         # now, the shape of x1_new is (1, 5)
    184         # array([[1, 2, 3, 4, 5]])
    185         #####################
    186         im_in = im_in[:,:,np.newaxis]#变为二维?
    187         
    188         #数组拼接
    189         #若axis=0,则要求除了a.shape[0]和b.shape[0]可以不等之外,其它维度必须相等
    190         #若axis=0,则要求除了a.shape[0]和b.shape[0]可以不等之外,其它维度必须相等
    191         #axis>=2 的情况以此类推,axis的值必须小于数组的维度
    192         im_in = np.concatenate((im_in,im_in,im_in), axis=2)
    193         
    194       # rgb -> bgr
    195       #line[:-1]其实就是去除了这行文本的最后一个字符(换行符)后剩下的部分。
    196       #line[::-1]字符串反过来 line = "abcde" line[::-1] 结果为:'edcba'
    197       im = im_in[:,:,::-1]#RGB->BGR
    198 
    199       blobs, im_scales = _get_image_blob(im)#图片变换 该文件上面定义的函数,返回处理后的值 和尺度
    200       assert len(im_scales) == 1, "Only single-image batch implemented"
    201       im_blob = blobs#处理后的值
    202       #图像信息,长、宽、尺度
    203       im_info_np = np.array([[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)
    204 
    205       #从numpy变为Tensor
    206       im_data_pt = torch.from_numpy(im_blob)
    207       #permute 将tensor的维度换位。
    208       #参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。
    209       #把索引为3的张量位置给提到前面了,例如128 128 3的图片变为 3 128 128
    210       im_data_pt = im_data_pt.permute(0, 3, 1, 2)
    211       #图像信息也变为tensor
    212       im_info_pt = torch.from_numpy(im_info_np)
    213 
    214       #将tensor的大小调整为指定的大小。
    215       #如果元素个数比当前的内存大小大,就将底层存储大小调整为与新元素数目一致的大小。
    216       im_data.data.resize_(im_data_pt.size()).copy_(im_data_pt)
    217       im_info.data.resize_(im_info_pt.size()).copy_(im_info_pt)
    218       gt_boxes.data.resize_(1, 1, 5).zero_()
    219       num_boxes.data.resize_(1).zero_()
    220 
    221       # pdb.set_trace()
    222       det_tic = time.time()#当前时间
    223 
    224       #参数带入模型
    225       #rois: 兴趣区域,怎么表示???????????
    226         # rois blob: holds R regions of interest, each is a 5-tuple
    227         # (n, x1, y1, x2, y2) specifying an image batch index n and a
    228         # rectangle (x1, y1, x2, y2)
    229         # top[0].reshape(1, 5)
    230       #cls_prob: softmax得到的概率值
    231       #bbox_pred: 偏移
    232       #rpn_loss_cls分类损失,计算softmax的损失,输入labels和cls layer的18个输出(中间reshape了一下),输出损失函数的具体值
    233       #rpn_loss_box 计算的框回归损失函数具体的值
    234       rois, cls_prob, bbox_pred, 
    235       rpn_loss_cls, rpn_loss_box, 
    236       RCNN_loss_cls, RCNN_loss_bbox, 
    237       rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
    238 
    239       scores = cls_prob.data#分类概率值
    240       ###################################################
    241       #boxes包含框的坐标
    242       #各维度表示什么??????????
    243       boxes = rois.data[:, :, 1:5]#?????????????????????
    244 
    245       if cfg.TEST.BBOX_REG:#Train bounding-box regressors TRUE or FALSE
    246           # Apply bounding-box regression deltas
    247           box_deltas = bbox_pred.data#偏移值
    248           if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
    249           # Optionally normalize targets by a precomputed mean and stdev
    250             if args.class_agnostic:
    251                 if args.cuda > 0:
    252                     #box_deltas.view改变维度
    253                     box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() 
    254                                + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
    255                 else:
    256                     box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS) 
    257                                + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)
    258 
    259                 box_deltas = box_deltas.view(1, -1, 4)
    260             else:
    261                 if args.cuda > 0:
    262                     box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() 
    263                                + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
    264                 else:
    265                     box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS) 
    266                                + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)
    267                 box_deltas = box_deltas.view(1, -1, 4 * len(pascal_classes))
    268 
    269          #model.rpn.bbox_transform 根据anchor和偏移量计算proposals
    270          #最后返回的是左上和右下顶点的坐标[x1,y1,x2,y2]。
    271          pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
    272          #model.rpn.bbox_transform 
    273          #将改变坐标信息后超过图像边界的框的边框裁剪一下,使之在图像边界之内
    274           pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
    275       else:
    276           # Simply repeat the boxes, once for each class
    277           #Numpy的 tile() 函数,就是将原矩阵横向、纵向地复制,这里是横向
    278           pred_boxes = np.tile(boxes, (1, scores.shape[1]))
    279 
    280       pred_boxes /= im_scales[0]
    281 
    282       #squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉
    283       scores = scores.squeeze()
    284       pred_boxes = pred_boxes.squeeze()
    285       det_toc = time.time()#当前时间
    286       detect_time = det_toc - det_tic#detect_time
    287       misc_tic = time.time()
    288       if vis:
    289           im2show = np.copy(im)
    290       for j in xrange(1, len(pascal_classes)):#所有类别
    291           #torch.nonzero
    292           #返回一个包含输入input中非零元素索引的张量,输出张量中的每行包含输入中非零元素的索引
    293           #若输入input有n维,则输出的索引张量output形状为z * n, 这里z是输入张量input中所有非零元素的个数
    294           inds = torch.nonzero(scores[:,j]>thresh).view(-1)#参数中的-1就代表这个位置由其他位置的数字来推断
    295           # if there is det
    296           #torch.numel() 返回一个tensor变量内所有元素个数,可以理解为矩阵内元素的个数
    297           if inds.numel() > 0:
    298             cls_scores = scores[:,j][inds]
    299             #torch.sort(input, dim=None, descending=False, out=None)有true,则表示降序,默认升序
    300             _, order = torch.sort(cls_scores, 0, True)#沿第0列降序
    301             if args.class_agnostic:#两类
    302               cls_boxes = pred_boxes[inds, :]
    303             else:
    304               cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]#why???
    305             
    306             #按行连接起来,torch.unsqueeze()这个函数主要是对数据维度进行扩充
    307             cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
    308             # cls_dets = torch.cat((cls_boxes, cls_scores), 1)
    309             cls_dets = cls_dets[order]
    310             #model.nms.nms_wrapper
    311             keep = nms(cls_dets, cfg.TEST.NMS, force_cpu=not cfg.USE_GPU_NMS)
    312             cls_dets = cls_dets[keep.view(-1).long()]
    313             if vis:
    314               #model.utils.net_utils
    315               im2show = vis_detections(im2show, pascal_classes[j], cls_dets.cpu().numpy(), 0.5)
    316 
    317       misc_toc = time.time()
    318       nms_time = misc_toc - misc_tic
    319 
    320       if webcam_num == -1:
    321           #当我们使用print(obj)在console上打印对象的时候,实质上调用的是sys.stdout.write(obj+'
    ')
    322           sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s   
    ' 
    323                            .format(num_images + 1, len(imglist), detect_time, nms_time))
    324           sys.stdout.flush()
    325 
    326       if vis and webcam_num == -1:
    327           # cv2.imshow('test', im2show)
    328           # cv2.waitKey(0)
    329           result_path = os.path.join(args.image_dir, imglist[num_images][:-4] + "_det.jpg")
    330           cv2.imwrite(result_path, im2show)
    331       else:
    332           im2showRGB = cv2.cvtColor(im2show, cv2.COLOR_BGR2RGB)
    333           cv2.imshow("frame", im2showRGB)
    334           total_toc = time.time()
    335           total_time = total_toc - total_tic
    336           frame_rate = 1 / total_time
    337           print('Frame rate:', frame_rate)
    338           if cv2.waitKey(1) & 0xFF == ord('q'):
    339               break
    340   if webcam_num >= 0:
    341       cap.release()
    342       cv2.destroyAllWindows()

    REF:YF-Zhang

  • 相关阅读:
    [独孤九剑]Oracle知识点梳理(五)数据库常用对象之Table、View
    [独孤九剑]Oracle知识点梳理(四)SQL语句之DML和DDL
    [独孤九剑]Oracle知识点梳理(三)导入、导出
    [独孤九剑]Oracle知识点梳理(二)数据库的连接
    [独孤九剑]Oracle知识点梳理(一)表空间、用户
    [独孤九剑]Oracle知识点梳理(零)目录
    jmeter安装
    MongoDB 用Robomong可视化工具操作的 一些简单语句
    限制输入字数JS
    我们来谈谈最近最热门的微信小程序
  • 原文地址:https://www.cnblogs.com/wind-chaser/p/11353466.html
Copyright © 2011-2022 走看看