zoukankan      html  css  js  c++  java
  • TensorFlow SSD代码的运行,小的修改

    原始代码地址

    需要注意的地方:

    1. 需要将checkpoint文件解压,并将代码中的checkpoint路径修改为正确的目录。

    2. 需要修改图片(img)的读取路径。

    改动的地方:原始代码检测后显示的类别是数字编号,不够直观,如下:

    修改代码后的结果如下:

    只需修改代码文件visualization.py。代码如下(修改部分被注释包裹,主要是读取标签字典,按类别编号查找对应的类别名并显示。注意修改后需要重启kernel再运行,否则运行的仍是修改前的代码):

    # Copyright 2017 Paul Balanca. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
    import cv2
    import random
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    import matplotlib.cm as mpcm
    
    
    # =========================================================================== #
    # Some colormaps.
    # =========================================================================== #
    def colors_subselect(colors, num_classes=21):
        """Pick `num_classes` evenly spaced entries from a color list.

        Entries are taken at indices 0, step, 2*step, ... where
        step = len(colors) // num_classes. An entry whose first channel is a
        float is scaled by 255 and truncated to int; any other entry is copied
        channel-by-channel into a new list.

        Args:
            colors: sequence of color entries (each an iterable of channels).
            num_classes: how many colors to return.

        Returns:
            A list of `num_classes` colors, each a list of channel values.
        """
        def _as_int_channels(entry):
            # Float entries (e.g. Matplotlib colormap values) become 0-255 ints.
            if isinstance(entry[0], float):
                return [int(channel * 255) for channel in entry]
            return list(entry)

        step = len(colors) // num_classes
        return [_as_int_channels(colors[k * step]) for k in range(num_classes)]
    
    # 21 integer RGB colors sampled evenly from Matplotlib's `plasma` colormap.
    colors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21)

    # Fixed 21-entry RGB palette, one color per class index; index 0 is white,
    # followed by ten dark/light color pairs.
    colors_tableau = [
        (255, 255, 255),
        (31, 119, 180), (174, 199, 232),
        (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138),
        (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213),
        (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210),
        (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141),
        (23, 190, 207), (158, 218, 229),
    ]
    
    
    # =========================================================================== #
    # OpenCV drawing.
    # =========================================================================== #
    def draw_lines(img, lines, color=(255, 0, 0), thickness=2):
        """Draw a collection of lines on an image, in place.

        Args:
            img: image array that cv2.line draws into (modified in place).
            lines: iterable of groups; each group is an iterable of
                (x1, y1, x2, y2) segment endpoints.
            color: line color, passed straight to cv2.line. The default is a
                tuple rather than a list so no mutable object is shared
                across calls.
            thickness: line thickness in pixels.
        """
        for line in lines:
            for x1, y1, x2, y2 in line:
                cv2.line(img, (x1, y1), (x2, y2), color, thickness)
    
    
    def draw_rectangle(img, p1, p2, color=(255, 0, 0), thickness=2):
        """Draw a rectangle on `img` in place.

        Args:
            img: image array that cv2.rectangle draws into.
            p1, p2: opposite corners given as (row, col); they are reversed to
                the (x, y) order that OpenCV expects.
            color: edge color passed to cv2.rectangle (tuple default, so no
                mutable default argument).
            thickness: edge thickness in pixels.
        """
        cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)
    
    
    def draw_bbox(img, bbox, shape, label, color=(255, 0, 0), thickness=2):
        """Draw one bounding box and its label on `img` in place.

        Args:
            img: image array drawn into by OpenCV.
            bbox: (ymin, xmin, ymax, xmax) scaled by `shape` to pixels
                (presumably normalized to [0, 1] -- confirm with callers).
            shape: image shape; shape[0] scales rows, shape[1] scales columns.
            label: anything; drawn via str(label).
            color: box/text color (tuple default avoids a mutable default).
            thickness: rectangle edge thickness in pixels.
        """
        # Corners in (row, col) order; reversed to (x, y) for OpenCV.
        p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
        p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
        cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)
        # Place the label 15 rows below the top-left corner, inside the box.
        p1 = (p1[0] + 15, p1[1])
        cv2.putText(img, str(label), p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 0.5, color, 1)
    
    
    def bboxes_draw_on_img(img, classes, scores, bboxes, colors, thickness=2):
        """Draw every box in `bboxes` onto `img` in place using OpenCV.

        Each box is bboxes[k] = (ymin, xmin, ymax, xmax) scaled by the image
        height (img.shape[0]) and width (img.shape[1]). The box and a
        "<class>/<score>" caption are drawn in colors[classes[k]]; the caption
        sits 5 pixels above the box's top-left corner.
        """
        height, width = img.shape[0], img.shape[1]
        for k in range(bboxes.shape[0]):
            box = bboxes[k]
            cls = classes[k]
            color = colors[cls]
            # OpenCV points are (x, y), i.e. (column, row).
            top_left = (int(box[1] * width), int(box[0] * height))
            bottom_right = (int(box[3] * width), int(box[2] * height))
            cv2.rectangle(img, top_left, bottom_right, color, thickness)
            caption = '%s/%.3f' % (cls, scores[k])
            text_origin = (top_left[0], top_left[1] - 5)
            cv2.putText(img, caption, text_origin, cv2.FONT_HERSHEY_DUPLEX, 0.4, color, 1)
    
    
    # =========================================================================== #
    # Matplotlib show...
    # modified by wangjc, 2017.10.18
    # =========================================================================== #
    
    
    
    def plt_bboxes(img, classes, scores, bboxes, figsize=(10,10), linewidth=1.5):
        """Visualize bounding boxes with Matplotlib. Largely inspired by SSD-MXNET!

        Draws every detection whose class id is >= 0 as an unfilled rectangle
        captioned "<class name> | <score>", then calls plt.show(). Each class
        gets a random color the first time it appears.

        Args:
            img: image passed to plt.imshow; img.shape[0] is used as the height
                and img.shape[1] as the width.
            classes: per-detection class ids (array with .shape); ids < 0 are
                skipped.
            scores: per-detection scores, printed with 3 decimals.
            bboxes: boxes indexed as bboxes[i, 0..3] = (ymin, xmin, ymax, xmax)
                relative to the image size (presumably in [0, 1] -- confirm).
            figsize: Matplotlib figure size.
            linewidth: rectangle edge width.
        """

        #################added
        # class id -> display name, so the label table is scanned at most once
        # per distinct class rather than once per box.
        class_names = {}

        def num2class(n):
            """Return the VOC label name whose entry contains `n`, else str(n).

            The str(n) fallback replaces the old implicit `None`, which made
            the '{:s}' caption format raise TypeError for unknown ids.
            """
            if n not in class_names:
                # Imported lazily (only once a box is actually drawn) because
                # this module path is specific to the author's checkout; point
                # it at wherever SSD-Tensorflow's `datasets` package lives.
                import tensorflow.models.SSD_Tensorflow_master.datasets.pascalvoc_2007 as pas
                name = str(n)
                # NOTE(review): each VOC_LABELS value is assumed to be a
                # container holding the numeric class id -- confirm.
                for label, item in pas.pascalvoc_common.VOC_LABELS.items():
                    if n in item:
                        name = label
                        break
                class_names[n] = name
            return class_names[n]
        ###########################added

        plt.figure(figsize=figsize)
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        ax = plt.gca()
        colors = dict()
        for i in range(classes.shape[0]):
            cls_id = int(classes[i])
            if cls_id < 0:
                continue
            score = scores[i]
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            # Scale relative box coordinates to pixels.
            ymin = int(bboxes[i, 0] * height)
            xmin = int(bboxes[i, 1] * width)
            ymax = int(bboxes[i, 2] * height)
            xmax = int(bboxes[i, 3] * width)
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=linewidth)
            ax.add_patch(rect)
            ###################added
            class_name = num2class(cls_id)
            ##################added
            ax.text(xmin, ymin - 2,
                    '{:s} | {:.3f}'.format(class_name, score),
                    bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                    fontsize=12, color='white')
        plt.show()
  • 相关阅读:
    AngularJs学习笔记--directive
    angularjs 路由(1)
    走进AngularJs(一)angular基本概念的认识与实战
    angularjs- 快速入门
    从angularJS看MVVM
    中软国际 问题一php的优缺点
    elasticsearch head安装后无法连接到es服务器问题
    Laravel5.3 流程粗粒度分析之bootstrap
    mysql执行大量sql语句
    Laravel RuntimeException inEncrypter.php line 43: The only supported ciphers are AES-128-CBC and AES-256-CBC with the correct key lengths
  • 原文地址:https://www.cnblogs.com/Osler/p/7687282.html
Copyright © 2011-2022 走看看