zoukankan      html  css  js  c++  java
  • 通过类来实现多个 Session 同时运行

    #xilerihua
    import tensorflow as tf
    import numpy as np
    import os
    from PIL import Image
    import matplotlib.pyplot as plt
    import numpy as np
    import sys
    
    #objectlocation
    import six.moves.urllib as urllib
    import tarfile
    import matplotlib
    matplotlib.use('Agg')
    from collections import defaultdict
    from io import StringIO
    from matplotlib import pyplot as plt
    from PIL import Image
    from utils import label_map_util
    from utils import visualization_utils as vis_util
    import time
    
    class multi():
        """Keep two frozen TensorFlow models loaded and run either one on demand.

        Each model is imported into its own ``tf.Graph`` with a dedicated
        ``tf.Session``, so both stay resident in one process and are reused
        across ``get_result`` calls instead of being reloaded every time.

        * ``type == '1'``: Faster R-CNN localisation. Prints (and returns) the
          pixel box ``(x1, y1, x2, y2)`` whose centre is closest to the image
          centre.
        * ``type == '2'``: Inception-v3 product classifier. Prints (and returns)
          the Chinese product name, or an "out of range" message.
        """

        # Model / label locations on the original author's machine.
        FASTER_RCNN_PB = r'E:/Project/TaoBaoLocation_new/research/object_detection/faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb'
        INCEPTION_PB = r'E:/Project/XiLeRiHuaReg/inception_v3_model/output_graph.pb'
        INCEPTION_LABELS = 'E:/Project/XiLeRiHuaReg/inception_v3_model/output_labels.txt'

        # Message printed when the classifier result is not one of the products.
        OUT_OF_RANGE = '不在17个范围之内'

        # Classifier label -> Chinese product name.
        HUMAN_KANJI = {
            'baby wipes': '婴儿湿巾',
            'bath towel': '洗澡巾',
            'convenient toothpick box': '便捷牙具盒',
            'dish rack': '沥水架',
            'hooks4': '挂钩粘钩4个装',
            'kitchen towel': '厨房方巾',
            'towel': '毛巾',
            'macaron basin': '马卡龙家用多用盆',
            'multi functional dental box': '多功能牙具盒',
            'paring knife': '削皮刀',
            'pineapple towel set': '菠萝纹毛巾浴巾套装',
            'rubbish bag': '垃圾袋',
            'sponge': '清洁海绵',
            'stainless hook': '不锈钢多用挂钩',
            'storage boxes': '三格储物盒',
            'towel set': '毛巾浴巾套装',
            'usb cable': '数据线',
            'liquor': '劲酒'
        }

        def __init__(self):
            """Load both frozen graphs and open one session per graph."""
            self.faster_graph, self.faster_sess = self._load_frozen_graph(self.FASTER_RCNN_PB)
            self.inception_graph, self.inception_sess = self._load_frozen_graph(self.INCEPTION_PB)

        @staticmethod
        def _load_frozen_graph(pb_path):
            """Import the frozen GraphDef at *pb_path* into a fresh graph.

            Returns:
                ``(graph, session)`` where the session is bound to that graph.
            """
            graph = tf.Graph()
            with graph.as_default():
                graph_def = tf.GraphDef()
                with tf.gfile.GFile(pb_path, 'rb') as fid:
                    graph_def.ParseFromString(fid.read())
                tf.import_graph_def(graph_def, name='')
            return graph, tf.Session(graph=graph)

        @staticmethod
        def _parse_labels(lines):
            """Map each line index to its label text, without surrounding whitespace.

            ``strip()`` also drops a stray ``\\r`` from CRLF files, which would
            otherwise make the label miss every ``HUMAN_KANJI`` key.
            """
            return {uid: line.strip() for uid, line in enumerate(lines)}

        def _load_labels(self):
            """Return the classifier's index -> label map, reading the file only once."""
            labels = getattr(self, '_uid_to_human', None)
            if labels is None:
                with tf.gfile.GFile(self.INCEPTION_LABELS) as fid:
                    labels = self._parse_labels(fid.readlines())
                self._uid_to_human = labels
            return labels

        @classmethod
        def _classify_label(cls, human_string, score, thres=0.6):
            """Turn the top classifier label and its score into the message to show.

            Low-confidence results, labels missing from ``HUMAN_KANJI`` (which
            previously raised KeyError), and the '劲酒' class all yield
            ``OUT_OF_RANGE``.
            """
            name = cls.HUMAN_KANJI.get(human_string)
            if score < thres or name is None or name == '劲酒':
                return cls.OUT_OF_RANGE
            return name

        @staticmethod
        def _select_center_box(boxes, scores, width, height, thres=0.55):
            """Return the confident detection whose centre is nearest the image centre.

            Args:
                boxes: detection boxes; reshaped to rows of four normalised
                    coordinates read as ``[ymin, xmin, ymax, xmax]``, the order
                    the original pixel conversion used. Any leading batch
                    dimension is flattened away.
                scores: one confidence per box (same count as ``boxes``).
                width, height: image size in pixels.
                thres: a box must score strictly above this to be considered.

            Returns:
                ``(x1, y1, x2, y2)`` integer pixel coordinates, or ``None`` if no
                box clears ``thres``.
            """
            # reshape(-1, ...) instead of the former hard-coded 100 detections.
            boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
            scores = np.asarray(scores, dtype=float).reshape(-1)
            pixel_boxes = [
                (int(float(b[1] * width)), int(float(b[0] * height)),
                 int(float(b[3] * width)), int(float(b[2] * height)))
                for b in boxes[scores > thres]
            ]
            if not pixel_boxes:
                return None

            x_center, y_center = width / 2, height / 2

            def _distance(box):
                # Box centre is truncated to int, as in the original code.
                bx, by = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
                return np.sqrt((x_center - bx) ** 2 + (y_center - by) ** 2)

            # min() keeps the first box on ties, like list.index(min(...)) did.
            return min(pixel_boxes, key=_distance)

        @staticmethod
        def _fallback_box(width, height):
            """Return the image bounds shrunk by a 10% margin, as ``(x1, y1, x2, y2)``."""
            w_margin, h_margin = int(width / 10), int(height / 10)
            return (w_margin, h_margin, width - w_margin, height - h_margin)

        def _recognize_product(self, image_path):
            """Classify *image_path* with Inception-v3; print and return the message."""
            uid_to_human = self._load_labels()
            softmax_tensor = self.inception_sess.graph.get_tensor_by_name('final_result:0')
            with tf.gfile.GFile(image_path, 'rb') as fid:
                image_data = fid.read()
            predictions = self.inception_sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
            predictions = np.ravel(predictions)

            if predictions.size == 0:
                result = self.OUT_OF_RANGE
            else:
                node_id = int(np.argmax(predictions))  # top-1 class
                result = self._classify_label(uid_to_human.get(node_id, ''), predictions[node_id])
            print(result)
            return result

        def _locate_object(self, image_path):
            """Run Faster R-CNN on *image_path*; print and return the chosen box."""
            # convert('RGB') accepts grayscale/RGBA/palette images too;
            # the with-block closes the file handle.
            with Image.open(image_path) as image:
                image_np = np.asarray(image.convert('RGB'), dtype=np.uint8)
            height, width = image_np.shape[:2]

            graph = self.faster_graph
            image_tensor = graph.get_tensor_by_name('image_tensor:0')
            boxes, scores = self.faster_sess.run(
                [graph.get_tensor_by_name('detection_boxes:0'),
                 graph.get_tensor_by_name('detection_scores:0')],
                feed_dict={image_tensor: image_np[np.newaxis, ...]})

            det = self._select_center_box(boxes, scores, width, height)
            if det is None:
                # No confident detection: fall back to the image minus a 10% border.
                det = self._fallback_box(width, height)
            print(*det)
            return det

        def get_result(self, type, image_path):
            """Run one model on *image_path*, print the result and return it.

            Args:
                type: ``'1'`` for Faster R-CNN localisation, ``'2'`` for the
                    Inception-v3 classifier. The name shadows the builtin but is
                    kept for caller compatibility.
                image_path: path of the image to process.

            Returns:
                The printed box tuple for ``'1'`` and the printed string for
                ``'2'``. Any other ``type`` prints nothing and returns ``None``,
                as before.
            """
            if type == '2':
                return self._recognize_product(image_path)
            if type == '1':
                return self._locate_object(image_path)
            return None
    
    
    if __name__ == '__main__':
        # Load both models once (the expensive part), then time repeated
        # inference with the graphs already resident in memory. The instance
        # gets its own name so it no longer shadows the ``multi`` class, and the
        # guard keeps `import` of this module from running inference.
        model = multi()

        for _ in range(5):
            start_t = time.perf_counter()
            model.get_result("1", "1.jpg")
            end_t = time.perf_counter()
            print('t1:', end_t - start_t)
            model.get_result("2", "1.jpg")
            print('t2:', time.perf_counter() - end_t)
    

      

  • 相关阅读:
    php排序算法-冒泡排序
    Mac安装java JDK
    mysql索引简单记录一下
    Mac 通过 pecl安装 redis 扩展
    Mac通过pecl安装swoole时遇到的坑(root + openssl)
    php获取两个日期的之间的日期信息,返回数组
    2021.4.9训练
    王道数据结构代码:双向链表的操作
    王道数据结构代码:单链表的操作
    PTA 7-1 大炮打蚊子 (15 分)
  • 原文地址:https://www.cnblogs.com/liutianrui1/p/10914121.html
Copyright © 2011-2022 走看看