zoukankan      html  css  js  c++  java
  • 【深度学习】PyTorch之Squeeze()和Unsqueeze()

    1. unsqueeze()

    该函数用来增加某个维度。在PyTorch中维度是从0开始的。

    import torch
    
    a = torch.arange(0, 9)
    print(a)

    结果:

    tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

    利用view()改变tensor的形状。值得注意的是view不会修改自身的数据,返回的新tensor与源tensor共享内存;同时必须保证前后元素总数一致。

    a = a.view(3, 3)
    print(f"a:{a} 
     shape:{a.shape}")

    结果:

    a:tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]]) 
     shape:torch.Size([3, 3])

    在第一个维度(即维度序号为0)前增加一个维度。

    a = a.unsqueeze(0)
    print(f"a:{a}
    shape:{a.shape}")

    结果:

    a:tensor([[[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]])
    shape:torch.Size([1, 3, 3])

    同理,可在其他位置添加维度,在这里就不举例了。

    2. squeeze()

    该函数用来减少某个维度。

    print(f"1.   a:{a}
    shape:{a.shape}")
    a = a.unsqueeze(0)
    a = a.unsqueeze(2)
    print(f"2.   a:{a}
    shape:{a.shape}")
    a = a.squeeze(2)
    print(f"3.   a:{a}
    shape:{a.shape}")

    结果:

    1.   a:tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    shape:torch.Size([3, 3])
    2.   a:tensor([[[[0, 1, 2]],
    
             [[3, 4, 5]],
    
             [[6, 7, 8]]]])
    shape:torch.Size([1, 3, 1, 3])
    3.   a:tensor([[[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]])
    shape:torch.Size([1, 3, 3])

    3. 下面是运用上述两个函数,并进行一次卷积的例子。

    from torchvision.transforms import  ToTensor
    import torch as t
    from torch import nnimport cv2
    import numpy as np
    import cv2
    to_tensor
    = ToTensor() # 加载图像 lena = cv2.imread('lena.jpg', cv2.IMREAD_GRAYSCALE) cv2.imshow('lena', lena) # input = to_tensor(lena) 将ndarray转换为tensor,自动将[0,255]归一化至[0,1]。 input = to_tensor(lena).unsqueeze(0) # 初始化卷积参数 kernel = t.ones(1, 1, 3, 3)/-9 kernel[:, :, 1, 1] = 1 conv = nn.Conv2d(1, 1, 3, 1, padding=1, bias=False) conv.weight.data = kernel.view(1, 1, 3, 3) # 输出 out = conv(input) out = out.squeeze(0) print(out.shape) out = out.unsqueeze(3) print(out.shape) out = out.squeeze(0) print(out.shape) out = out.detach().numpy()
    # 缩放到0~最大值 cv2.normalize(out, out,
    1.0, 0, cv2.NORM_INF) cv2.imshow("lena-result", out) cv2.waitKey()

    结果:

             

    torch.Size([1, 304, 304])
    torch.Size([1, 304, 304, 1])
    torch.Size([304, 304, 1])
    <class 'numpy.ndarray'> (304, 304, 1)

    参考文献

    [1] 陈云.深度学习框架之PyTorch入门与实践.北京:电子工业出版社,2018.

  • 相关阅读:
    给定一个无序数组arr,求出需要排序的最短子数组长度。例如: arr = [1,5,3,4,2,6,7] 返回4,因为只有[5,3,4,2]需要排序。
    Given n pairs of parentheses, write a function to generate all combinations of well-formed parentheses. For example, given n = 3, a solution set is: "((()))", "(()())", "(())()", "()(())", "()()()"
    shell数组
    学习ansible(一)
    nginx搭建简单直播服务器
    rsync
    Linux运维最常用150个命令
    Linux 三剑客
    学习Python(一)
    学习k8s(三)
  • 原文地址:https://www.cnblogs.com/chen-hw/p/11678949.html
Copyright © 2011-2022 走看看