zoukankan      html  css  js  c++  java
  • 【pytorch-ssd目标检测】可视化检测结果

    制作类似pascal voc格式的目标检测数据集:https://www.cnblogs.com/xiximayou/p/12546061.html

    训练自己创建的数据集:https://www.cnblogs.com/xiximayou/p/12546556.html

    验证自己创建的数据集:https://www.cnblogs.com/xiximayou/p/12550471.html

    测试自己创建的数据集:https://www.cnblogs.com/xiximayou/p/12550566.html

    还是以在谷歌colab上为例:

    cd /content/drive/My Drive/pytorch_ssd

    导入相应的包:

    import os
    import sys
    module_path = os.path.abspath(os.path.join('..'))
    if module_path not in sys.path:
        sys.path.append(module_path)
    
    import torch
    import torch.nn as nn
    import torch.backends.cudnn as cudnn
    from torch.autograd import Variable
    import numpy as np
    import cv2
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    
    from ssd import build_ssd

    加载谷歌网盘:

    from google.colab import drive
    drive.mount('/content/drive')

    加载模型:

    net = build_ssd('test', 300, 3)    # initialize SSD
    net.load_weights('weights/ssd300_MASK_5000.pth')

    可视化要检测的图像:

    # image = cv2.imread('./data/example.jpg', cv2.IMREAD_COLOR)  # uncomment if dataset not downloaded
    %matplotlib inline
    from matplotlib import pyplot as plt
    from data import MASKDetection, MASK_ROOT, MASKAnnotationTransform
    # here we specify year (07 or 12) and dataset ('test', 'val', 'train') 
    mask_root="/content/drive/My Drive/pytorch_ssd"
    testset = MASKDetection(mask_root, "val", None, MASKAnnotationTransform())
    img_id = 2
    image = testset.pull_image(img_id)
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # View the sampled input image before transform
    plt.figure(figsize=(10,10))
    plt.imshow(rgb_image)
    plt.show()

    调整图片的格式:

    x = cv2.resize(image, (300, 300)).astype(np.float32)
    x -= (104.0, 117.0, 123.0)
    x = x.astype(np.float32)
    x = x[:, :, ::-1].copy()
    plt.imshow(x)
    x = torch.from_numpy(x).permute(2, 0, 1)

    使用模型进行预测:

    xx = Variable(x.unsqueeze(0))     # wrap tensor in Variable
    if torch.cuda.is_available():
        xx = xx.cuda()
    y = net(xx)

    输出结果:

    from data import MASK_CLASSES as labels
    top_k=3
    
    plt.figure(figsize=(10,10))
    colors = plt.cm.hsv(np.linspace(0, 1, 3)).tolist()
    plt.imshow(rgb_image)  # plot the image for matplotlib
    currentAxis = plt.gca()
    
    detections = y.data
    # scale each detection back up to the image
    scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2)
    for i in range(detections.size(1)):
        j = 0
        while detections[0,i,j,0] >= 0.6:
            score = detections[0,i,j,0]
            label_name = labels[i-1]
            display_txt = '%s: %.2f'%(label_name, score)
            pt = (detections[0,i,j,1:]*scale).cpu().numpy()
            coords = (pt[0], pt[1]), pt[2]-pt[0]+1, pt[3]-pt[1]+1
            color = colors[i]
            currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2))
            currentAxis.text(pt[0], pt[1], display_txt, bbox={'facecolor':color, 'alpha':0.5})
            j+=1

    由于我的数据集中很少没有戴口罩的样本,因此没有戴口罩的AP较低。

    至此,使用pytorch-ssd训练测试自己数据集就全部完成啦。 

  • 相关阅读:
    【转载】C++针对ini配置文件读写大全
    CString向char类型转化 ---“=”: 无法从“wchar_t *”转换为“char *
    使用了非标准扩展:“xxx”使用 SEH,并且“xxx”有析构函数
    16进制串hex与ASCII字符串相互转换
    【转载】CCombobox使用大全
    获取c++ edit控件内容
    [转载]C++ CString与int 互转
    MacOS Cocos2d-x-3.2 创建HelloWorld项目
    构建之法阅读笔记6--敏捷开发2
    进度条--第十二周
  • 原文地址:https://www.cnblogs.com/xiximayou/p/12552854.html
Copyright © 2011-2022 走看看