Object Detection with a Pre-trained Model in TensorFlow (Part 3): Storing Detection Results in a MySQL Database

    MySQL version: 5.7; database: rdshare. The table captain_america3_sd records whether a given frame has been detected, and the table captain_america3_d records the detection data itself.
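    The post does not show the table definitions. Below is a minimal sketch of schemas consistent with how the code later in this post uses them: the captain_america3_sd columns (frame_num, is_detected) are taken from the queries, while the captain_america3_d columns are assumptions.

    # Hypothetical schema setup -- captain_america3_sd columns match the queries
    # used below; captain_america3_d's columns are assumed, not shown in the post.
    import MySQLdb

    conn = MySQLdb.connect(user='root', passwd='****', host='localhost',
                           port=3306, db='rdshare', charset='utf8')
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS captain_america3_sd (
            frame_num   INT PRIMARY KEY,        -- frame index in the video
            is_detected TINYINT(1) DEFAULT 0    -- 1 once the frame has been processed
        )
    """)
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS captain_america3_d (
            id          INT AUTO_INCREMENT PRIMARY KEY,
            frame_num   INT,                    -- which frame the object was found in
            object_name VARCHAR(64),            -- class label, e.g. 'person'
            score       FLOAT                   -- detection confidence
        )
    """)
    conn.commit()
    cursor.close()
    conn.close()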

    For background on Python modules and packages, see http://www.runoob.com/python/python-modules.html and https://www.cnblogs.com/ningskyer/articles/6025964.html

    I. Connecting to the database

    For reference, a snippet from an earlier project that writes a file path into MySQL:

    # Reference snippet: store a video file path in the database
    import MySQLdb

    def video_insert(filename, couse_id):
        conn = MySQLdb.connect(user='root', passwd='****', host='sh-cdb-myegtz7i.sql.tencentcdb.com',
                               port=63619, db='bitbear', charset='utf8')
        cursor = conn.cursor()

        # Look up the course_report record whose courseh_id equals the parsed course_id
        # to obtain courser_id. courseh_id is the course_id from the course record table;
        # courser_id is the primary key of the course report table; course_id is the id
        # used in this program.
        sql = "SELECT courser_id FROM course_report WHERE courseh_id = '%s'" % (couse_id)
        cursor.execute(sql)
        results = cursor.fetchall()
        if results:
            print(results)
            courser_id = results[0][0]
            print(results[0][0])

            # Path of the file to store
            #rarpath = os.getcwd()
            rarpath = filename
            print(rarpath)

            # Write the path back into the report record
            sql = "UPDATE course_report SET json = '%s' WHERE courser_id = '%s'" % (rarpath, courser_id)
            cursor.execute(sql)
            conn.commit()
        cursor.close()
        conn.close()

    First install the MySQL driver:  sudo apt-get install python-mysqldb

    After installation, test it in the Python interpreter:
    enter import MySQLdb  (note the capitalisation)
    If no error is raised, the driver is installed correctly.
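    A slightly fuller interpreter check can also verify that the MySQL server is reachable; the host and credentials below are placeholders mirroring the snippets in this post.

    # Quick check: a successful import means the driver is installed;
    # the connect() call additionally verifies the server/database is reachable.
    import MySQLdb
    print(MySQLdb.__version__)                 # driver version
    conn = MySQLdb.connect(user='root', passwd='****', host='localhost',
                           port=3306, db='rdshare', charset='utf8')
    conn.close()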
    A minimal test version:
    # Store a detection result in the MySQL database (minimal version)
    import MySQLdb

    def detection_to_database(object_name):
        conn = MySQLdb.connect(user='root', passwd='****', host='localhost',
                               port=3306, db='rdshare', charset='utf8')
        cursor = conn.cursor()

        #sql = "SELECT person FROM captain_america3_d WHERE id = 1"
        #cursor.execute(sql)
        #results = cursor.fetchall()
        #if results:
        #    print(results)

        sql = "INSERT INTO captain_america3_sd (is_detected) VALUES (1)"
        cursor.execute(sql)
        conn.commit()
        cursor.close()
        conn.close()
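    The snippets in this post build SQL with Python string formatting. As a side note, MySQLdb also supports parameterized queries, which let the driver escape values and avoid SQL injection. A minimal sketch of the same insert written that way (table and column names as above, the frame number is just an example value):

    # Same insert as above, but letting the driver escape the values.
    # MySQLdb uses %s placeholders together with a parameter tuple.
    import MySQLdb

    conn = MySQLdb.connect(user='root', passwd='****', host='localhost',
                           port=3306, db='rdshare', charset='utf8')
    cursor = conn.cursor()
    cursor.execute(
        "INSERT INTO captain_america3_sd (is_detected, frame_num) VALUES (%s, %s)",
        (1, 36000))   # 36000 is just an example frame number
    conn.commit()
    cursor.close()
    conn.close()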

    II. Modifying the file structure

    Create a new file detection_control.py in the same directory. It acts as the main entry point: it parses the arguments and drives the detection loop. A usage example follows the listing below.

    #!/usr/bin/python
    # -*- coding: utf-8 -*-

    import datetime
    import os
    import time
    import argparse
    import detection as mod_detection
    import sys
    reload(sys)                      # Python 2: allow resetting the default encoding
    sys.setdefaultencoding('utf8')

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'   # silence TensorFlow log output


    def parse_args():
        '''parse args'''
        parser = argparse.ArgumentParser()
        parser.add_argument('--image_path', default='/home/yanjieliu/rdshare/dataset/ca36000_36100/')
        parser.add_argument('--image_start_num', default='36000')
        parser.add_argument('--image_end_num', default='36002')
        parser.add_argument('--model_name',
                            default='ssd_inception_v2_coco_2018_01_28')
        return parser.parse_args()

    if __name__ == '__main__':
        args = parse_args()
        for frame_num in range(int(args.image_start_num), int(args.image_end_num)):
            print(frame_num)
            # Call the Detection function in detection.py, passing the args and frame number
            mod_detection.Detection(args, frame_num)
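    With the defaults above, the controller can be run directly; the values here simply repeat the defaults and process frames 36000 and 36001, writing one status row per frame into captain_america3_sd:

        python detection_control.py --image_path /home/yanjieliu/rdshare/dataset/ca36000_36100/ --image_start_num 36000 --image_end_num 36002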

    For each frame it calls the Detection function in detection.py to run the actual recognition.

    The content of detection.py is as follows:

    #!/usr/bin/python
    # -*- coding: utf-8 -*-

    import numpy as np
    import matplotlib
    matplotlib.use('Agg')            # non-interactive backend (no display needed)
    import matplotlib.pyplot
    from matplotlib import pyplot as plt
    import os
    import tensorflow as tf
    from PIL import Image
    from object_detection.utils import label_map_util
    from object_detection.utils import visualization_utils as vis_util

    import datetime
    import time
    import MySQLdb
    import argparse
    import sys
    reload(sys)                      # Python 2: allow resetting the default encoding
    sys.setdefaultencoding('utf8')

    # Suppress TensorFlow log output
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    
    detection_graph = tf.Graph()
    
    
    # Store the detection result for one frame in the MySQL database
    def detection_to_database(object_name, frame_num):
        conn = MySQLdb.connect(user='root', passwd='****', host='localhost',
                               port=3306, db='rdshare', charset='utf8')
        cursor = conn.cursor()

        # Check the detection-status table: if this frame_num has already been
        # detected, update its row; otherwise insert a new one.
        sql = "SELECT is_detected FROM captain_america3_sd WHERE frame_num = '%s'" % (frame_num)
        cursor.execute(sql)
        results = cursor.fetchall()
        if results:
            print(results)
            sql = "UPDATE captain_america3_sd SET is_detected=1 WHERE frame_num = '%s'" % (frame_num)
        else:
            print('null')
            sql = "INSERT INTO captain_america3_sd (is_detected, frame_num) VALUES (1, '%s')" % (frame_num)

        cursor.execute(sql)
        conn.commit()
        cursor.close()
        conn.close()
        # Note: object_name is received but not yet stored anywhere;
        # see the sketch after this listing for writing it into captain_america3_d.
    
    
    # Load the model data -------------------------------------------------------------------------------------------------
    def loading(model_name):
    
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            PATH_TO_CKPT = '/home/yanjieliu/models/models/research/object_detection/pretrained_models/'+model_name + '/frozen_inference_graph.pb'
            with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph
    
    
    
    # Detection -----------------------------------------------------------------------------------------------------------
    def load_image_into_numpy_array(image):
        (im_width, im_height) = image.size
        return np.array(image.getdata()).reshape(
            (im_height, im_width, 3)).astype(np.uint8)
    # List of the strings that is used to add correct label for each box.
    PATH_TO_LABELS = os.path.join('/home/yanjieliu/models/models/research/object_detection/data', 'mscoco_label_map.pbtxt')
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=90, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)
    
    def Detection(args, frame_num):
        image_path=args.image_path
        loading(args.model_name)
        #start = time.time()
        with detection_graph.as_default():
            with tf.Session(graph=detection_graph) as sess:
                # for image_path in TEST_IMAGE_PATHS:
                image = Image.open('%simage-%s.jpeg'%(image_path, frame_num))
    
                # the array based representation of the image will be used later in order to prepare the
                # result image with boxes and labels on it.
                image_np = load_image_into_numpy_array(image)
    
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image_np, axis=0)
                image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    
                # Each box represents a part of the image where a particular object was detected.
                boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    
                # Each score represent how level of confidence for each of the objects.
                # Score is shown on the result image, together with the class label.
                scores = detection_graph.get_tensor_by_name('detection_scores:0')
                classes = detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
    
                # Visualization of the results of a detection: draw the boxes and labels on the image
                vis_util.visualize_boxes_and_labels_on_image_array(
                     image_np,
                     np.squeeze(boxes),
                     np.squeeze(classes).astype(np.int32),
                     np.squeeze(scores),
                     category_index,
                     use_normalized_coordinates=True,
                     line_thickness=8)
                # Output the top results and store them
                for i in range(3):
                    if int(classes[0][i]) in category_index.keys():
                        class_name = category_index[int(classes[0][i])]['name']
                        detection_to_database(class_name, frame_num)
                    else:
                        class_name = 'N/A'
                    print("object:%s score:%s" % (class_name, scores[0][i]))
                    
                # Render the annotated image with matplotlib
                # Size, in inches, of the output images.
                IMAGE_SIZE = (20, 12)
                plt.figure(figsize=IMAGE_SIZE)
                plt.imshow(image_np)
                plt.show()   # with the Agg backend nothing is displayed; use plt.savefig(...) to write the image to a file
    
    def parse_args():
        '''parse args'''
        parser = argparse.ArgumentParser()
        parser.add_argument('--image_path', default='/home/yanjieliu/rdshare/dataset/ca36000_36100/')
        parser.add_argument('--image_start_num', default='36000')
        parser.add_argument('--image_end_num', default='36002')
        parser.add_argument('--model_name',
                            default='ssd_inception_v2_coco_2018_01_28')
        return parser.parse_args()
    
    
    
    if __name__ == '__main__':
        # Standalone run: detect the start frame and report the elapsed time
        args = parse_args()
        frame_num = int(args.image_start_num)
        start = time.time()
        Detection(args, frame_num)
        end = time.time()
        print('time:')
        print(str(end - start))
    
    
    
    
    # Optionally append the elapsed time to a file for later statistics
    #    with open('./outputs/1to10test_outputs.txt', 'a') as f:
    #        f.write('\n')
    #        f.write(str(end-start))
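    The detection_to_database function above only marks a frame as processed in captain_america3_sd; the class name it receives is not stored. Below is a minimal sketch of how the per-object data could be written into captain_america3_d. The column names (frame_num, object_name, score) are an assumption, since the post does not show that table's definition.

    # Hypothetical helper -- writes one detected object into captain_america3_d.
    # Column names are assumed; adjust them to the actual table definition.
    import MySQLdb

    def detection_data_to_database(object_name, score, frame_num):
        conn = MySQLdb.connect(user='root', passwd='****', host='localhost',
                               port=3306, db='rdshare', charset='utf8')
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO captain_america3_d (frame_num, object_name, score) VALUES (%s, %s, %s)",
            (frame_num, object_name, float(score)))
        conn.commit()
        cursor.close()
        conn.close()

    It could be called right next to detection_to_database inside the output loop, e.g. detection_data_to_database(class_name, scores[0][i], frame_num).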
    Original post: https://www.cnblogs.com/vactor/p/10031414.html