  • [Neural Networks] Step-by-Step Visual Recognition with Mobile-Net (Part 5)

    1. Environment setup

    2. Dataset acquisition

    3. Training set acquisition

    4. Training

    5. Running and testing the trained model

    6. Code walkthrough

      This is the fifth article in the series; it covers how to load the trained model and test the results.

    In the previous article we exported the trained model; in this one we load that model and use it to run the tests.

    Create test.py in the object_detection directory and enter the following:

    import os
    import cv2
    import numpy as np
    import tensorflow as tf
    import sys
    sys.path.append("..")
    from utils import label_map_util
    from utils import visualization_utils as vis_util
    
    ENEMY = 2  # 1 = blue team, 2 = red team; the team set here is treated as the enemy
    DEBUG = False  # True: draw every box scoring above 0.8; False: box only the best enemy detection
    THRE_VAL = 0.2  # minimum score for a detection to be accepted
    
    PATH_TO_CKPT ='/home/xueaoru/models/research/inference_graph_v2/frozen_inference_graph.pb'
    PATH_TO_LABELS = '/home/xueaoru/models/research/object_detection/car_label_map.pbtxt'
    NUM_CLASSES = 2
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)
    detection_graph = tf.Graph()
    
    
    # Load the frozen detection graph and start a TensorFlow session on it.
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

        sess = tf.Session(graph=detection_graph)
    # Input and output tensors of the detection graph.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    
    def video_test():
        # cap = cv2.VideoCapture(1)  # uncomment to read from a live camera instead
        cap = cv2.VideoCapture("/home/xueaoru/下载/RoboMaster2.mp4")
        while True:
            time = cv2.getTickCount()
            ret, image = cap.read()
            if not ret:
                break
            image_expanded = np.expand_dims(image, axis=0)  # [1, h, w, 3]
    
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_expanded})
            #print(np.squeeze(classes).astype(np.int32))
            #print(np.squeeze(scores))
            #print(np.squeeze(boxes))
            vis_util.visualize_boxes_and_labels_on_image_array(
                image,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8,
                min_score_thresh=0.4)
    
            cv2.imshow('Object detector', image)
            key = cv2.waitKey(1) & 0xff
            time = cv2.getTickCount() - time
            print("Processing time: " + str(time * 1000 / cv2.getTickFrequency()) + " ms")
            if key == 27:  # Esc key quits
                break
        cap.release()
        cv2.destroyAllWindows()

    def pic_test():
        image = cv2.imread("/home/xueaoru/models/research/images/image12.jpg")
        image_expanded = np.expand_dims(image, axis=0)  # [1, h, w, 3]
    
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_expanded})
        
        if DEBUG:
            vis_util.visualize_boxes_and_labels_on_image_array(
                image,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8,
                min_score_thresh=0.80)
        else:
            # Keep only the highest-scoring detection.
            score = np.squeeze(scores)
            max_index = np.argmax(score)
            score = score[max_index]
            detected_class = np.squeeze(classes).astype(np.int32)[max_index]
            if score > THRE_VAL and detected_class == ENEMY:
                box = np.squeeze(boxes)[max_index]  # (ymin, xmin, ymax, xmax), normalized
                h, w, _ = image.shape
                min_point = (int(box[1] * w), int(box[0] * h))
                max_point = (int(box[3] * w), int(box[2] * h))
                cv2.rectangle(image, min_point, max_point, (0, 255, 255), 2)

        cv2.imshow('Object detector', image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


    video_test()
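
    As written, the script runs video_test() when executed; to check a single image instead, replace that final call with pic_test().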

    That's it for now. The final article will walk through the code in detail, including how to go from these detected boxes to computing the turret's deflection angles; the line-by-line explanation of this script is deferred to that article as well.
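    As a preview of that calculation, here is a minimal sketch of one common approach, assuming a simple pinhole camera whose fields of view are known. The constants H_FOV_DEG and V_FOV_DEG and the function turret_angles below are illustrative assumptions, not values from this project:

    import numpy as np

    # Assumed camera parameters -- replace with your own calibrated values.
    H_FOV_DEG = 60.0  # hypothetical horizontal field of view, in degrees
    V_FOV_DEG = 45.0  # hypothetical vertical field of view, in degrees

    def turret_angles(box, image_w, image_h):
        """Map a normalized box (ymin, xmin, ymax, xmax) to (yaw, pitch) in degrees."""
        ymin, xmin, ymax, xmax = box
        # Box center in pixel coordinates.
        cx = (xmin + xmax) / 2.0 * image_w
        cy = (ymin + ymax) / 2.0 * image_h
        # Offset of the center from the middle of the image (assumed optical center).
        dx = cx - image_w / 2.0
        dy = cy - image_h / 2.0
        # Pinhole model: focal length in pixels derived from the field of view.
        fx = (image_w / 2.0) / np.tan(np.radians(H_FOV_DEG) / 2.0)
        fy = (image_h / 2.0) / np.tan(np.radians(V_FOV_DEG) / 2.0)
        yaw = np.degrees(np.arctan2(dx, fx))     # positive: target is to the right
        pitch = -np.degrees(np.arctan2(dy, fy))  # positive: target is above center
        return yaw, pitch

    Feeding the best box from pic_test through such a function yields the pan and tilt offsets needed to center the target, assuming the camera axis is aligned with the barrel.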
