zoukankan      html  css  js  c++  java
  • Pytorch中保存图片(tensor,cv2,pillow)

    tensor直接保存

    #!/usr/bin/env python
    # _*_ coding:utf-8 _*_
    import torch
    from torchvision import utils as vutils
     
     
    def save_image_tensor(input_tensor: torch.Tensor, filename):
        """
        将tensor保存为图片
        :param input_tensor: 要保存的tensor
        :param filename: 保存的文件名
        """
        assert (len(input_tensor.shape) == 4 and input_tensor.shape[0] == 1)
        # 复制一份
        input_tensor = input_tensor.clone().detach()
        # 到cpu
        input_tensor = input_tensor.to(torch.device('cpu'))
        # 反归一化
        # input_tensor = unnormalize(input_tensor)
        vutils.save_image(input_tensor, filename)

    tensor转cv2保存

    如果你是先转numpy,再交换维度,一定用transpose,而不是swapaxes,不然颜色会出问题= =

    就像下面这张图

    原图

    tensor转cv2保存 正确的代码

    #!/usr/bin/env python
    # _*_ coding:utf-8 _*_
    import torch
    import cv2
     
     
    def save_image_tensor2cv2(input_tensor: torch.Tensor, filename):
        """
        将tensor保存为cv2格式
        :param input_tensor: 要保存的tensor
        :param filename: 保存的文件名
        """
        assert (len(input_tensor.shape) == 4 and input_tensor.shape[0] == 1)
        # 复制一份
        input_tensor = input_tensor.clone().detach()
        # 到cpu
        input_tensor = input_tensor.to(torch.device('cpu'))
        # 反归一化
        # input_tensor = unnormalize(input_tensor)
        # 去掉批次维度
        input_tensor = input_tensor.squeeze()
        # 从[0,1]转化为[0,255],再从CHW转为HWC,最后转为cv2
        input_tensor = input_tensor.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy()
        # RGB转BRG
        input_tensor = cv2.cvtColor(input_tensor, cv2.COLOR_RGB2BGR)
        cv2.imwrite(filename, input_tensor)

    tensor转pillow保存

    def save_image_tensor2pillow(input_tensor: torch.Tensor, filename):
        """
        将tensor保存为pillow
        :param input_tensor: 要保存的tensor
        :param filename: 保存的文件名
        """
        assert (len(input_tensor.shape) == 4 and input_tensor.shape[0] == 1)
        # 复制一份
        input_tensor = input_tensor.clone().detach()
        # 到cpu
        input_tensor = input_tensor.to(torch.device('cpu'))
        # 反归一化
        # input_tensor = unnormalize(input_tensor)
        # 去掉批次维度
        input_tensor = input_tensor.squeeze()
        # 从[0,1]转化为[0,255],再从CHW转为HWC,最后转为numpy
        input_tensor = input_tensor.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy()
        # 转成pillow
        im = Image.fromarray(input_tensor)
        im.save(filename)
  • 相关阅读:
    MFC常用数据类型
    第四章 菜单、工具栏和状态栏(第8课)
    Bugku-CTF之文件包含2 (150)
    Bugku-CTF之本地包含( 60)
    Bugku-CTF之前女友(SKCTF)
    《Web安全攻防 渗透测试实战指南》 学习笔记 (四)
    Nmap工具用法详解
    《Web安全攻防 渗透测试实战指南 》 学习笔记 (三)
    《Web安全攻防 渗透测试实战指南》 学习笔记 (二)
    Bugku-CTF之PHP_encrypt_1(ISCCCTF) [fR4aHWwuFCYYVydFRxMqHhhCKBseH1dbFygrRxIWJ1UYFhotFjA=]
  • 原文地址:https://www.cnblogs.com/zhaoyingjie/p/14636217.html
Copyright © 2011-2022 走看看