zoukankan      html  css  js  c++  java
  • yolo源码解析(三)

    七 测试网络

     模型测试包含于test.py文件,Detector类的image_detector()函数用于检测目标。

    复制代码
    import os
    import cv2
    import argparse
    import numpy as np
    import tensorflow as tf
    import yolo.config as cfg
    from yolo.yolo_net import YOLONet
    from utils.timer import Timer
    
    '''
    用于测试
    '''
    
    class Detector(object):
    复制代码

    1、类初始化函数

    复制代码
     def __init__(self, net, weight_file):
            '''
            构造函数
            利用 cfg 文件对网络参数进行初始化,
            其中 offset 的作用应该是一个定长的偏移
            boundery1和boundery2 作用是在输出中确定每种信息的长度(如类别,置信度等)。
            其中 boundery1 指的是对于所有的 cell 的类别的预测的张量维度,所以是 self.cell_size * self.cell_size * self.num_class
            boundery2 指的是在类别之后每个cell 所对应的 bounding boxes 的数量的总和,所以是self.boundary1 + self.cell_size * self.cell_size * self.boxes_per_cell
            
            
            args:
                net:YOLONet对象
                weight_file:检查点文件路径
            '''
            #yolo网络
            self.net = net
            #检查点文件路径
            self.weights_file = weight_file
            #输出文件夹路径
            self.output_dir = os.path.dirname(self.weights_file)
             #VOC 2012数据集类别名
            self.classes = cfg.CLASSES
            # #VOC 2012数据类别数
            self.num_class = len(self.classes)
            ##图像大小
            self.image_size = cfg.IMAGE_SIZE
            #单元格大小S
            self.cell_size = cfg.CELL_SIZE
            #每个网格边界框的个数B=2
            self.boxes_per_cell = cfg.BOXES_PER_CELL
            #阈值参数
            self.threshold = cfg.THRESHOLD
            #IoU 阈值参数
            self.iou_threshold = cfg.IOU_THRESHOLD
            '''#将网络输出分离为类别和置信度以及边界框的大小,输出维度为7*7*20 + 7*7*2 + 7*7*2*4=1470'''
             #7*7*20
            self.boundary1 = self.cell_size * self.cell_size * self.num_class
             #7*7*20+7*7*2
            self.boundary2 = self.boundary1 +
                self.cell_size * self.cell_size * self.boxes_per_cell
    
            #运行图之前,初始化变量
            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())
    
            #恢复模型
            print('Restoring weights from: ' + self.weights_file)
            self.saver = tf.train.Saver()
            #直接载入最近保存的检查点文件
            ckpt = tf.train.latest_checkpoint(self.output_dir)
            print("ckpt:",ckpt)         
            #如果存在检查点文件 则恢复模型
            if ckpt!=None:
                #恢复最近的检查点文件
                self.saver.restore(self.sess, ckpt) 
            else:
                #从指定检查点文件恢复
                self.saver.restore(self.sess, self.weights_file)
    复制代码

    2、draw_result()函数

    在原始图像上绘制边界框,并添加一些附件信息,如目标类别,置信度。

    复制代码
        def draw_result(self, img, result):
            '''
            在原图上绘制边界框,以及附加信息
            
            args:
                img:原始图片数据
                result:yolo网络目标检测到的边界框,list类型 每一个元素对应一个目标框 
                      包含{类别名,x_center,y_center,w,h,置信度} 
            '''
            #遍历每一个边界框
            for i in range(len(result)):
                #x_center
                x = int(result[i][1])
                #y_center
                y = int(result[i][2])
                #w/2
                w = int(result[i][3] / 2)
                #h/2
                h = int(result[i][4] / 2)
                #绘制矩形框(目标边界框) 矩形左上角,矩形右下角
                cv2.rectangle(img, (x - w, y - h), (x + w, y + h), (0, 255, 0), 2)            
                #绘制矩形框,用于存放类别名称,使用灰度填充
                cv2.rectangle(img, (x - w, y - h - 20),
                              (x + w, y - h), (125, 125, 125), -1)
                #线型
                lineType = cv2.LINE_AA if cv2.__version__ > '3' else cv2.CV_AA
                #绘制文本信息 写上类别名和置信度
                cv2.putText(
                    img, result[i][0] + ' : %.2f' % result[i][5],
                    (x - w + 5, y - h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                    (0, 0, 0), 1, lineType)
    复制代码

    3、detect()函数

    detect()函数用来对图像进行目标检测。

    复制代码
     def detect(self, img):
            '''
            图片目标检测
            
            args:
                img:原始图片数据
                
            return:
                result:返回检测到的边界框,list类型 每一个元素对应一个目标框 
                包含{类别名,x_center,y_center,w,h,置信度}
            '''
            #获取图片的高和宽
            img_h, img_w, _ = img.shape
            #图片缩放 [448,448,3]
            inputs = cv2.resize(img, (self.image_size, self.image_size))
            #BGR->RGB  uint->float32
            inputs = cv2.cvtColor(inputs, cv2.COLOR_BGR2RGB).astype(np.float32)
            #归一化处理 [-1.0,1.0]
            inputs = (inputs / 255.0) * 2.0 - 1.0
            #reshape [1,448,448,3]
            inputs = np.reshape(inputs, (1, self.image_size, self.image_size, 3))
    
            #获取网络输出第一项(即第一张图片) [1,1470]
            result = self.detect_from_cvmat(inputs)[0]
    
            #对检测的图片的边界框进行缩放处理,一张图片可以有多个边界框
            for i in range(len(result)):
                #x_center, y_center, w, h都是真实值,分别表示预测边界框的中心坐标,宽和高,都是浮点型
                result[i][1] *= (1.0 * img_w / self.image_size)    #x_center
                result[i][2] *= (1.0 * img_h / self.image_size)    #y_center
                result[i][3] *= (1.0 * img_w / self.image_size)    #w
                result[i][4] *= (1.0 * img_h / self.image_size)    #h
    
            #<class 'list'> 6 ['person', 405.83171163286482, 161.40340532575334, 166.17623397282193, 298.85661533900668, 0.69636690616607666]
            #Average detecting time: 0.571s
            print(type(result),len(result),result[0])
            return result
    复制代码

    4、detect_from_cvmat()函数

    复制代码
     def detect_from_cvmat(self, inputs):
            '''
            运行yolo网络,开始检测
            
            args:
                inputs:输入数据  [None,448,448,3]
                
            return:
                results:返回目标检测的结果,每一个元素对应一个测试图片,每个元素包含着若干个边界框
            
            '''
            #返回网络最后一层,激活函数处理之前的值  形状[None,1470]
            net_output = self.sess.run(self.net.logits,
                                       feed_dict={self.net.images: inputs})
            results = []
            
            #对网络输出每一行数据进行处理
            for i in range(net_output.shape[0]):
                results.append(self.interpret_output(net_output[i]))
    
            #返回处理后的结果
            return results
    复制代码

     5、interpret_output()函数

    该函数对yolo网络输出的结果进行处理,提取出有目标的边界框,方便后续的处理。

    复制代码
     def interpret_output(self, output):
            '''
            对yolo网络输出进行处理  
            
            args:
                output:yolo网络输出的每一行数据 大小为[1470,]
                        0:7*7*20:表示预测类别   
                        7*7*20:7*7*20 + 7*7*2:表示预测置信度,即预测的边界框与实际边界框之间的IOU
                        7*7*20 + 7*7*2:1470:预测边界框    目标中心是相对于当前格子的,宽度和高度的开根号是相对当前整张图像的(归一化的)        
                        
            return:
                 result:yolo网络目标检测到的边界框,list类型 每一个元素对应一个目标框 
                      包含{类别名,x_center,y_center,w,h,置信度}   实际上这个置信度是yolo网络输出的置信度confidence和预测对应的类别概率的乘积
            '''
            #[7,7,2,20]
            probs = np.zeros((self.cell_size, self.cell_size,
                              self.boxes_per_cell, self.num_class))
            #类别概率 [7,7,20]
            class_probs = np.reshape(
                output[0:self.boundary1],
                (self.cell_size, self.cell_size, self.num_class))
            #置信度 [7,7,2]
            scales = np.reshape(
                output[self.boundary1:self.boundary2],
                (self.cell_size, self.cell_size, self.boxes_per_cell))
            #边界框 [7,7,2,4]
            boxes = np.reshape(
                output[self.boundary2:],
                (self.cell_size, self.cell_size, self.boxes_per_cell, 4))
            #[14,7]  每一行[0,1,2,3,4,5,6]
            offset = np.array(
                [np.arange(self.cell_size)] * self.cell_size * self.boxes_per_cell)
            #[7,7,2] 每一行都是  [[0,0],[1,1],[2,2],[3,3],[4,4],[5,5],[6,6]]
            offset = np.transpose(
                np.reshape(
                    offset,
                    [self.boxes_per_cell, self.cell_size, self.cell_size]),
                (1, 2, 0))
    
            #目标中心是相对于整个图片的
            boxes[:, :, :, 0] += offset
            boxes[:, :, :, 1] += np.transpose(offset, (1, 0, 2))
            boxes[:, :, :, :2] = 1.0 * boxes[:, :, :, 0:2] / self.cell_size
            #宽度、高度相对整个图片的
            boxes[:, :, :, 2:] = np.square(boxes[:, :, :, 2:])
    
            #转换成实际的编辑框(没有归一化的)
            boxes *= self.image_size
    
            #遍历每一个边界框的置信度
            for i in range(self.boxes_per_cell):
                #遍历每一个类别
                for j in range(self.num_class):
                    #在测试时,乘以条件类概率和单个盒子的置信度预测,这些分数编码了j类出现在框i中的概率以及预测框拟合目标的程度。
                    probs[:, :, i, j] = np.multiply(
                        class_probs[:, :, j], scales[:, :, i])
    
            #[7,7,2,20] 如果第i个边界框检测到类别j 则[;,;,i,j]=1
            filter_mat_probs = np.array(probs >= self.threshold, dtype='bool')
            #返回filter_mat_probs非0值的索引 返回4个List,每个list长度为n  即检测到的边界框的个数      
            filter_mat_boxes = np.nonzero(filter_mat_probs)
            #获取检测到目标的边界框 [n,4]  n表示边界框的个数
            boxes_filtered = boxes[filter_mat_boxes[0],
                                   filter_mat_boxes[1], filter_mat_boxes[2]]        
            #获取检测到目标的边界框的置信度 (n,)
            probs_filtered = probs[filter_mat_probs]  
            #获取检测到目标的边界框对应的目标类别 (n,)
            classes_num_filtered = np.argmax(
                filter_mat_probs, axis=3)[
                filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]    
            #按置信度倒序排序,返回对应的索引
            argsort = np.array(np.argsort(probs_filtered))[::-1]
            boxes_filtered = boxes_filtered[argsort]
            probs_filtered = probs_filtered[argsort]
            classes_num_filtered = classes_num_filtered[argsort]
    
            for i in range(len(boxes_filtered)):
                if probs_filtered[i] == 0:
                    continue
                for j in range(i + 1, len(boxes_filtered)):
                    #计算n各边界框,两两之间的IoU是否大于阈值,非极大值抑制                          
                    if self.iou(boxes_filtered[i], boxes_filtered[j]) :
                        probs_filtered[j] = 0.0
    
            #非极大值抑制后的输出
            filter_iou = np.array(probs_filtered > 0.0, dtype='bool')
            boxes_filtered = boxes_filtered[filter_iou]
            probs_filtered = probs_filtered[filter_iou]
            classes_num_filtered = classes_num_filtered[filter_iou]
    
            result = []
            #遍历每一个边界框
            for i in range(len(boxes_filtered)):
                result.append(
                    [self.classes[classes_num_filtered[i]],  #类别名
                     boxes_filtered[i][0],                   #x中心
                     boxes_filtered[i][1],                   #y中心
                     boxes_filtered[i][2],                   #宽度
                     boxes_filtered[i][3],                   #高度
                     probs_filtered[i]])                     #置信度  
    
            return result
    复制代码

    6、iou()函数

    计算两个边界框的IoU值。

    复制代码
        def iou(self, box1, box2):
            '''
            计算两个边界框的IoU
            
            args:
                box1:边界框1  [4,]   真实值
                box2:边界框2  [4,]   真实值
            '''
            tb = min(box1[0] + 0.5 * box1[2], box2[0] + 0.5 * box2[2]) - 
                max(box1[0] - 0.5 * box1[2], box2[0] - 0.5 * box2[2])
            lr = min(box1[1] + 0.5 * box1[3], box2[1] + 0.5 * box2[3]) - 
                max(box1[1] - 0.5 * box1[3], box2[1] - 0.5 * box2[3])
            inter = 0 if tb < 0 or lr < 0 else tb * lr
            return inter / (box1[2] * box1[3] + box2[2] * box2[3] - inter)
    复制代码

    7、camera_detector()函数

    调用摄像头实现实时目标检测。

    复制代码
        def camera_detector(self, cap, wait=10):
            '''
            打开摄像头,实时检测
            
            '''
            #测试时间
            detect_timer = Timer()
            #读取一帧
            ret, _ = cap.read()
    
            while ret:
                #读取一帧
                ret, frame = cap.read()
                #测试其实时间
                detect_timer.tic()
                result = self.detect(frame)
                #测试结束时间
                detect_timer.toc()
                print('Average detecting time: {:.3f}s'.format(
                    detect_timer.average_time))
                #绘制边界框,以及添加附加信息
                self.draw_result(frame, result)
                #显示
                cv2.imshow('Camera', frame)
                cv2.waitKey(wait)
    复制代码

    8、image_detector()函数

    对图片进行目标检测。

    复制代码
        def image_detector(self, imname, wait=0):
            '''
            目标检测
            
            args:
                imname:测试图片路径
            '''
            #检测时间
            detect_timer = Timer()
            #读取图片
            image = cv2.imread(imname)
            #image = cv2.resize(image,(int(image.shape[1]/2),int(image.shape[0]/2)))
            #检测的起始时间
            detect_timer.tic()
            #开始检测
            result = self.detect(image)
            #检测的结束时间
            detect_timer.toc()
            print('Average detecting time: {:.3f}s'.format(
                detect_timer.average_time))
            #绘制检测结果
            self.draw_result(image, result)
            cv2.imshow('Image', image)
            cv2.waitKey(wait)
    复制代码

    介绍完了Detector这个类,我们来看一下main函数。该函数比较检测,首先解析命令行参数,然后创建yolo网络,以及检测器对象,最后调用image_detector()函数对图片进行目标检测。

    复制代码
    def main():
        #创建一个解析器对象,并告诉它将会有些什么参数。当程序运行时,该解析器就可以用于处理命令行参数。
        #https://www.cnblogs.com/lovemyspring/p/3214598.html
        parser = argparse.ArgumentParser()
        #定义参数
        parser.add_argument('--weights', default="YOLO_small.ckpt", type=str)
        parser.add_argument('--weight_dir', default='weights', type=str)
        parser.add_argument('--data_dir', default="data", type=str)
        parser.add_argument('--gpu', default='', type=str)
        #定义了所有参数之后,你就可以给 parse_args() 传递一组参数字符串来解析命令行。默认情况下,参数是从 sys.argv[1:] 中获取
        #parse_args() 的返回值是一个命名空间,包含传递给命令的参数。该对象将参数保存其属性
        args = parser.parse_args()
    
        #设置环境变量
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    
        #创建YOLO网络对象
        yolo = YOLONet(False)
        #加载检查点文件
        weight_file = os.path.join(args.data_dir, args.weight_dir, args.weights)
        weight_file = './data/pascal_voc/weights/YOLO_small.ckpt' 
        #weight_file = './data/pascal_voc/output/2018_07_09_17_00/yolo.ckpt-1000'
        
        #创建测试对象
        detector = Detector(yolo, weight_file)
    
        # detect from camera
        # cap = cv2.VideoCapture(-1)
        # detector.camera_detector(cap)
    
        # detect from image file
        imname = 'test/car.jpg'
        detector.image_detector(imname)
    复制代码

    我们执行如下代码,开始测试网络:

    if __name__ == '__main__':
        tf.reset_default_graph()
        main()

    我们可以看到yolo网络对小目标检测效果并不好,漏检了一个目标。这主要与yolo的网络结构以及损失函数有关。除此之外yolo网络还有一些其他缺点,我们总结如下:

    • 漏检。每个网格只预测一个类别的边界框,而且最后只取置信度最大的那个边界框。这就导致如果多个不同物体(或者同类物体的不同实体)的中心落在同一个网格中,会造成漏检。yolo对相互靠的很近的物体,还有很小的群体检测效果不好,这是因为一个网格中只预测了两个框,并且只属于一类。
    • 位置精准性差。召回率低。由于损失函数的问题,定位误差是影响检测效果的主要原因。尤其是大小物体的处理上,还有待加强。
    • 对测试图像中,同一类物体出现的新的不常见的长宽比和其他情况是。泛化能力偏弱。

    参考文章:

    [1]argparse - 命令行选项与参数解析(转)

    [2]Yolo v1详解及相关问题解答

  • 相关阅读:
    webservice 测试窗体只能用于来自本地计算机的请求
    未能加载文件或程序集system.web.extensions解决方法
    VS2010中水晶报表应用及实例
    存储过程
    Windows下wamp的配置问题(php初学者必看!!)
    IIS附加进程在Visual Studio 2010 中进行调试(高级)
    求职之(1)各公司待遇~~可能有点老了
    编译原理之(2)C++词法文件,语法文件
    STL笔记(4)关于erase,remove
    STL笔记(6)标准库:标准库中的排序算法
  • 原文地址:https://www.cnblogs.com/sddai/p/10288099.html
Copyright © 2011-2022 走看看