zoukankan      html  css  js  c++  java
  • torch 中各种图像格式转化

    • PIL:使用python自带图像处理库读取出来的图片格式
    • numpy:使用python-opencv库读取出来的图片格式
    • tensor:pytorch中训练时所采取的向量格式(当然也可以说图片)
    import torch
    from PIL import Image
    import matplotlib.pyplot as plt
    
    # loader使用torchvision中自带的transforms函数
    loader = transforms.Compose([
        transforms.ToTensor()])  
    
    unloader = transforms.ToPILImage()
    
    # 输入图片地址
    # 返回tensor变量
    def image_loader(image_name):
        image = Image.open(image_name).convert('RGB')
        image = loader(image).unsqueeze(0)#用来满足网络的输入维度的假batch维度,即不足之处补0
        return image.to(device, torch.float)
    
    # 输入PIL格式图片
    # 返回tensor变量
    def PIL_to_tensor(image):
        image = loader(image).unsqueeze(0)
        return image.to(device, torch.float)
    
    # 输入tensor变量
    # 输出PIL格式图片
    def tensor_to_PIL(tensor):
        image = tensor.cpu().clone()
        image = image.squeeze(0)#移除假batch维度,即删掉上面添加的0
        image = unloader(image)
        return image
    
    #直接展示tensor格式图片
    def imshow(tensor, title=None):
        image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
        image = image.squeeze(0)  # remove the fake batch dimension
        image = unloader(image)
        plt.imshow(image)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated
    
    #直接保存tensor格式图片
    def save_image(tensor, **para):
        dir = 'results'
        image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
        image = image.squeeze(0)  # remove the fake batch dimension
        image = unloader(image)
        if not osp.exists(dir):
            os.makedirs(dir)
        image.save('results_{}/s{}-c{}-l{}-e{}-sl{:4f}-cl{:4f}.jpg'
                   .format(num, para['style_weight'], para['content_weight'], para['lr'], para['epoch'],
                           para['style_loss'], para['content_loss']))
  • 相关阅读:
    Pytest框架之命令行参数2
    Pytest框架之命令行参数1
    [编程题] 二维数组中的查找
    [编程题]二叉树镜像
    补充基础:栈与队列模型
    6641. 【GDOI20205.20模拟】Sequence
    瞎讲:任意模数MTT
    瞎讲:FFT三次变二次优化
    小米oj 重拍数组求最大和
    小米oj 有多少个公差为2的等差数列
  • 原文地址:https://www.cnblogs.com/tingtin/p/12288619.html
Copyright © 2011-2022 走看看