zoukankan      html  css  js  c++  java
  • torchvision 批量可视化图片

    1.1 简介

    计算机视觉中,我们需要观察我们的神经网络输出是否合理。因此就需要进行可视化的操作。

    orchvision是独立于pytorch的关于图像操作的一些方便工具库。

    torchvision的详细介绍在:https://pypi.org/project/torchvision/0.1.8/

    这里主要使用的是make_grid函数,参数的tensor是一个 (B x C x H x W) - (Batchsize, Channel, Heigjt, Weight)的张量,nrow是输出图片网格的列数。padding是每张图片之间宽度间隔。

    make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)

    Example usage is given in this notebook<https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>

    举个例子。如果你的batch size 是一个(32,3,256,256)的一组图片,设置为nrow = 8,则最后输出的图片是一个4*8的网格,每个网格是一张图片。

    2.1 代码

    batch_image是([5, 3, 256, 256])大小的张量。

    batch_labels是 ([5, 15, 2]) 的坐标点。用于标记每张图中15个关键点的 [x, y] 坐标

    vis_flipped ([1, 5, 14]) 记录每个关键点可见的情况,0为不可见,1为可见

    output_root 是保存图片的路径

    i_loader是data loader 的索引

    j_loader是batch的索引

    代码的关键是要保存正确的关键的信息在一个大网格内,因此,需要把每个关键点的坐标,写一个for 循环。

    x = 行数*图片宽 + padding +x ,

    y = 列数*图片高 + padding +y

    import cv2
    import os
    import torchvision
    import numpy as npdef save_visualize_result(batch_image,labels,batch_labels,raw_image,vis_flipped,output_root,i_loader,j_batch):
        # batch_image.shape ([5, 3, 256, 256])
        # labels.shape ([1, 5, 15, 31, 31])
        # batch_labels.shape ([5,15,2])
        # raw_image.shape ([ 5, 3 , width_raw,height_raw ])
        # flipped_labels.shape ([1,5,28])[x1,x2,x3 ...,x14,y1,y2,y3...y14
        # vis_flipped [1, 5, 14]
        # i_loader -- which loader, j_batch -- which_batch
    
    
        batch_size, n_stages, n_joints = labels.shape[0], labels.shape[1], labels.shape[2]
        xmaps = n_stages
        ymaps = batch_size
    
        image_size = batch_image.shape[-2]
        label_size = labels.shape[-2]
        rotation = image_size / label_size
    
        grid = torchvision.utils.make_grid(batch_image, nrow=n_stages, padding=2, normalize=True)
        ndarr = grid.mul(255).clamp(0, 255).byte().cpu().permute(1, 2, 0).numpy()
        b, g, r = cv2.split(ndarr)
    
        ndarr = cv2.merge([r, g, b])
        ndarr = ndarr.copy()
        
        padding = 2
    
        height = int(batch_image.size(2) + padding)
        width = int(batch_image.size(3) + padding)
        k = 0
        # mpii_order = [13, 11, 9, 8, 10, 12, 4, 6, 14, 1, 7, 5, 3, 2]
        # transformed order [13, 11,  9,  8, 10, 12,  4,  6, 14,  1,  7,  5,  3,  2]
        names = ['ra', 'rk', 'rh', 'lh', 'lk', 'la', 'le', 'lw', 'neck', 'head', 'rw', 're', 'rs', 'ls']
    
        ### mapped ###
        k = 0
        for y in range(ymaps):
            for x in range(xmaps):
                raw_vis = vis_flipped[0, k, :]
                joints = batch_labels[k, :, :] * rotation
                for i_name, joint in enumerate(joints):
                    if i_name < 14:
                        if raw_vis[i_name] == 0:
                            continue
                        joint[0] = x * width + padding + joint[0]
                        joint[1] = y * height + padding + joint[1]
                        cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, [255, 0, 0], 2)
                        cv2.putText(ndarr, names[i_name], org=(int(joint[0]), int(joint[1])),
                                    fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.5, color=[0, 0, 255])
                k = k + 1
        cv2.imwrite(os.path.join(output_root, 'loader_' + str(i_loader) + '_batch_' + str(j_batch) + '_mapped.png'), ndarr)
        print('loader_' + str(i_loader) + '_batch_' + str(j_batch) + '_mapped.png' + 'saved successfuly!')
     

    3.1 结果

  • 相关阅读:
    IntelliJ IDEA 常用快捷键
    solr4.5分组查询、统计功能介绍
    用于Lucene的各中文分词比较
    Lucene打分规则与Similarity模块详解
    Lucene
    tar中的参数 cvf,xvf,cvzf,zxvf的区别
    tmux 入门踩坑记录
    第一个shell脚本
    make 和 make install 的区别
    交叉编译
  • 原文地址:https://www.cnblogs.com/siyuan1998/p/10697600.html
Copyright © 2011-2022 走看看