zoukankan      html  css  js  c++  java
  • torch 中各种图像格式转换

    • PIL:使用python自带图像处理库读取出来的图片格式
    • numpy:使用python-opencv库读取出来的图片格式
    • tensor:pytorch中训练时所采取的向量格式(当然也可以说图片)

    PIL与Tensor相互转换

    import torch
    from PIL import Image
    import matplotlib.pyplot as plt
    
    # loader使用torchvision中自带的transforms函数
    loader = transforms.Compose([
        transforms.ToTensor()])  
    
    unloader = transforms.ToPILImage()
    
    # 输入图片地址
    # 返回tensor变量
    def image_loader(image_name):
        image = Image.open(image_name).convert('RGB')
        image = loader(image).unsqueeze(0)
        return image.to(device, torch.float)
    
    # 输入PIL格式图片
    # 返回tensor变量
    def PIL_to_tensor(image):
        image = loader(image).unsqueeze(0)
        return image.to(device, torch.float)
    
    # 输入tensor变量
    # 输出PIL格式图片
    def tensor_to_PIL(tensor):
        image = tensor.cpu().clone()
        image = image.squeeze(0)
        image = unloader(image)
        return image
    
    #直接展示tensor格式图片
    def imshow(tensor, title=None):
        image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
        image = image.squeeze(0)  # remove the fake batch dimension
        image = unloader(image)
        plt.imshow(image)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated
    
    #直接保存tensor格式图片
    def save_image(tensor, **para):
        dir = 'results'
        image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
        image = image.squeeze(0)  # remove the fake batch dimension
        image = unloader(image)
        if not osp.exists(dir):
            os.makedirs(dir)
        image.save('results_{}/s{}-c{}-l{}-e{}-sl{:4f}-cl{:4f}.jpg'
                   .format(num, para['style_weight'], para['content_weight'], para['lr'], para['epoch'],
                           para['style_loss'], para['content_loss']))
    

    numpy 与 tensor相互转换

    import cv2
    import torch
    import matplotlib.pyplot as plt
    
    def toTensor(img):
        assert type(img) == np.ndarray,'the img type is {}, but ndarry expected'.format(type(img))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose((2, 0, 1)))
        return img.float().div(255).unsqueeze(0)  # 255也可以改为256
    
    def tensor_to_np(tensor):
        img = tensor.mul(255).byte()
        img = img.cpu().numpy().squeeze(0).transpose((1, 2, 0))
        return img
    
    def show_from_cv(img, title=None):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.figure()
        plt.imshow(img)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)
    
    
    def show_from_tensor(tensor, title=None):
        img = tensor.clone()
        img = tensor_to_np(img)
        plt.figure()
        plt.imshow(img)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)
    

    N张图片一起转换.

    # 将 N x H x W X C 的numpy格式图片转化为相应的tensor格式
    def toTensor(img):
        img = torch.from_numpy(img.transpose((0, 3, 1, 2)))
        return img.float().div(255).unsqueeze(0)
    

    参考:https://oldpan.me/archives/pytorch-tensor-image-transform

  • 相关阅读:
    宝宝的成长脚印9/2
    宝宝的成长脚印9/5
    手动作花灯10/6
    EasyUI中EasyLoader加载数组模块
    easyui常用属性
    VS2010如何在一个web项目中使用APP_CODE下的自定义类
    MSSQL系统常用视图命令及其作用
    db_autopwn渗透流程
    渗透测试工具Nmap从初级到高级
    EasyUI中在表单提交之前进行验证
  • 原文地址:https://www.cnblogs.com/sdu20112013/p/11983437.html
Copyright © 2011-2022 走看看