
    Reshaping operations

    Suppose we have the following tensor:

    t = torch.tensor([
        [1,1,1,1],
        [2,2,2,2],
        [3,3,3,3]
    ], dtype=torch.float32)
    

    We have two ways to get the shape:

    > t.size()
    torch.Size([3, 4])
    
    > t.shape
    torch.Size([3, 4])
    

    The rank of a tensor is equal to the length of the tensor's shape.

    > len(t.shape)
    2
    

    We can also deduce the number of elements contained within the tensor.

    > torch.tensor(t.shape).prod()
    tensor(12)
    

    In PyTorch, there is a dedicated function for this:

    > t.numel()
    12
    

    Reshaping a tensor in PyTorch

    > t.reshape([2,6])
    tensor([[1., 1., 1., 1., 2., 2.],
            [2., 2., 3., 3., 3., 3.]])
    
    > t.reshape([3,4])
    tensor([[1., 1., 1., 1.],
            [2., 2., 2., 2.],
            [3., 3., 3., 3.]])
    
    > t.reshape([4,3])
    tensor([[1., 1., 1.],
            [1., 2., 2.],
            [2., 2., 3.],
            [3., 3., 3.]])
    
    > t.reshape(6,2)
    tensor([[1., 1.],
            [1., 1.],
            [2., 2.],
            [2., 2.],
            [3., 3.],
            [3., 3.]])
    
    > t.reshape(12,1)
    tensor([[1.],
            [1.],
            [1.],
            [1.],
            [2.],
            [2.],
            [2.],
            [2.],
            [3.],
            [3.],
            [3.],
            [3.]])
    

    In this example, we increase the rank to 3:

    > t.reshape(2,2,3)
    tensor(
    [
        [
            [1., 1., 1.],
            [1., 2., 2.]
        ],
    
        [
            [2., 2., 3.],
            [3., 3., 3.]
        ]
    ])
    

    Note: PyTorch has another function, view(), that does nearly the same thing as reshape(). The difference is that view() requires the tensor's data to be contiguous in memory, while reshape() will copy the data if it has to.
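    The contrast between view() and reshape() can be seen on a non-contiguous tensor. A minimal sketch (the transpose here is just one way to produce a non-contiguous tensor):

```python
import torch

t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3]
], dtype=torch.float32)

# view() works here because t is contiguous in memory
print(t.view(2, 6).shape)      # torch.Size([2, 6])

# Transposing makes the tensor non-contiguous; reshape() still works
# (it copies the data when it must), but view() raises an error
nc = t.t()
print(nc.is_contiguous())      # False
print(nc.reshape(12).shape)    # torch.Size([12])
try:
    nc.view(12)
    view_failed = False
except RuntimeError:
    view_failed = True
print("view on non-contiguous tensor failed:", view_failed)
```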

    Changing shape by squeezing and unsqueezing

    These functions allow us to expand or shrink the rank (number of dimensions) of our tensor.

    > print(t.reshape([1,12]))
    > print(t.reshape([1,12]).shape)
    tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
    torch.Size([1, 12])
    
    > print(t.reshape([1,12]).squeeze())
    > print(t.reshape([1,12]).squeeze().shape)
    tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
    torch.Size([12])
    
    > print(t.reshape([1,12]).squeeze().unsqueeze(dim=0))
    > print(t.reshape([1,12]).squeeze().unsqueeze(dim=0).shape)
    tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
    torch.Size([1, 12])
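    One detail worth knowing: squeeze() with no argument removes every axis of length one, while squeeze(dim=...) only removes that axis, and only if its length is one. A quick sketch on a fresh tensor:

```python
import torch

t = torch.zeros(1, 12, 1)

print(t.squeeze().shape)        # torch.Size([12])      - all size-1 axes removed
print(t.squeeze(dim=0).shape)   # torch.Size([12, 1])   - only axis 0 removed
print(t.squeeze(dim=1).shape)   # torch.Size([1, 12, 1]) - axis 1 has size 12, so no-op
```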
    

    Let’s look at a common use case for squeezing a tensor by building a flatten function.

    Flatten a tensor

    Flattening a tensor means removing all of its axes except one.

    A flatten operation reshapes the tensor to a single axis whose length is equal to the number of elements contained in the tensor. This is the same thing as a 1d array of elements.

    def flatten(t):
        t = t.reshape(1, -1)
        t = t.squeeze()
        return t
    
    > t = torch.ones(4, 3)
    > t
    tensor([[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])
    
    > flatten(t)
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
    

    We'll see that flatten operations are required when passing an output tensor from a convolutional layer to a linear layer.

    In these examples, we have flattened the entire tensor, however, it is possible to flatten only specific parts of a tensor. For example, suppose we have a tensor of shape [2,1,28,28] for a CNN. This means that we have a batch of 2 grayscale images with height and width dimensions of 28 x 28, respectively.

    Here, we can flatten just the two images to get the shape [2,1,784]. We could also squeeze out the channel axis to get the shape [2,784].
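    A quick sketch of those two options on a hypothetical batch of this shape (the tensor values don't matter here, only the shapes):

```python
import torch

# Hypothetical batch: 2 grayscale images, 1 channel, 28x28 pixels
batch = torch.zeros(2, 1, 28, 28)

# Flatten only the height and width axes, keeping batch and channel
print(batch.flatten(start_dim=2).shape)                  # torch.Size([2, 1, 784])

# Additionally squeeze out the channel axis
print(batch.flatten(start_dim=2).squeeze(dim=1).shape)   # torch.Size([2, 784])
```

    Flattening with start_dim=1 would reach [2, 784] directly, since the size-1 channel axis gets folded in.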

    Concatenating tensors

    We combine tensors using the cat() function.

    > t1 = torch.tensor([
        [1,2],
        [3,4]
    ])
    > t2 = torch.tensor([
        [5,6],
        [7,8]
    ])
    

    We can combine t1 and t2 row-wise (axis-0) in the following way:

    > torch.cat((t1, t2), dim=0)
    tensor([[1, 2],
            [3, 4],
            [5, 6],
            [7, 8]])
    

    We can combine them column-wise (axis-1) like this:

    > torch.cat((t1, t2), dim=1)
    tensor([[1, 2, 5, 6],
            [3, 4, 7, 8]])
    

    Flatten operation for a batch of image inputs to a CNN

    Flattening specific axes of a tensor

    We know that the tensor inputs to a convolutional neural network typically have 4 axes, one for batch size, one for color channels, and one each for height and width.

    [Batch Size, Channels, Height, Width]

    To start, suppose we have the following three tensors.

    t1 = torch.tensor([
        [1,1,1,1],
        [1,1,1,1],
        [1,1,1,1],
        [1,1,1,1]
    ])
    
    t2 = torch.tensor([
        [2,2,2,2],
        [2,2,2,2],
        [2,2,2,2],
        [2,2,2,2]
    ])
    
    t3 = torch.tensor([
        [3,3,3,3],
        [3,3,3,3],
        [3,3,3,3],
        [3,3,3,3]
    ])
    

    Remember, batches are represented using a single tensor, so we’ll need to combine these three tensors into a single larger tensor that has three axes instead of two.

    > t = torch.stack((t1, t2, t3))
    > t.shape
    
    torch.Size([3, 4, 4])
    

    Here, we used the stack() function to concatenate our sequence of three tensors along a new axis.

    > t
    tensor([[[1, 1, 1, 1],
             [1, 1, 1, 1],
             [1, 1, 1, 1],
             [1, 1, 1, 1]],
    
            [[2, 2, 2, 2],
             [2, 2, 2, 2],
             [2, 2, 2, 2],
             [2, 2, 2, 2]],
    
            [[3, 3, 3, 3],
             [3, 3, 3, 3],
             [3, 3, 3, 3],
             [3, 3, 3, 3]]])
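    The relationship between stack() and cat() can be made concrete: stacking is the same as unsqueezing each tensor on a new axis 0 and then concatenating along that axis. A sketch using tensors like the ones above:

```python
import torch

t1 = torch.full((4, 4), 1)
t2 = torch.full((4, 4), 2)
t3 = torch.full((4, 4), 3)

# stack() adds a new axis 0 and joins the tensors along it
stacked = torch.stack((t1, t2, t3))

# Equivalent: unsqueeze each tensor on axis 0, then cat() along axis 0
catted = torch.cat((t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)), dim=0)

print(stacked.shape)                 # torch.Size([3, 4, 4])
print(torch.equal(stacked, catted))  # True
```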
    

    All we need to do now to get this tensor into a form that a CNN expects is add an axis for the color channels. We basically have an implicit single color channel for each of these image tensors, so in practice, these would be grayscale images.

    > t = t.reshape(3,1,4,4)
    > t
    tensor(
    [
        [
            [
                [1, 1, 1, 1],
                [1, 1, 1, 1],
                [1, 1, 1, 1],
                [1, 1, 1, 1]
            ]
        ],
        [
            [
                [2, 2, 2, 2],
                [2, 2, 2, 2],
                [2, 2, 2, 2],
                [2, 2, 2, 2]
            ]
        ],
        [
            [
                [3, 3, 3, 3],
                [3, 3, 3, 3],
                [3, 3, 3, 3],
                [3, 3, 3, 3]
            ]
        ]
    ])
    
    Flattening the tensor batch

    Here are some alternative implementations of the flatten() function.

    > t.reshape(1,-1)[0]
    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
    
    > t.reshape(-1)
    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
    
    > t.view(t.numel())
    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
    
    > t.flatten()
    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
    

    This flattened batch won’t work well inside our CNN because we need individual predictions for each image within our batch tensor, and now we have a flattened mess.

    The solution here is to flatten each image while still maintaining the batch axis. This means we want to flatten only part of the tensor: the color channel axis together with the height and width axes.

    > t.flatten(start_dim=1).shape
    torch.Size([3, 16])
    
    > t.flatten(start_dim=1)
    tensor(
    [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
    ]
    )
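    This per-image flatten is exactly what connects a convolutional layer to a linear layer inside a network. A minimal sketch (the layer sizes here are made up for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical layers: 1 input channel, 6 output channels, 5x5 kernel,
# followed by a linear layer sized to the flattened feature maps
conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
fc = nn.Linear(in_features=6 * 24 * 24, out_features=10)

batch = torch.randn(2, 1, 28, 28)     # batch of 2 grayscale 28x28 images
features = conv(batch)                # shape [2, 6, 24, 24]
flat = features.flatten(start_dim=1)  # shape [2, 3456]; batch axis preserved
out = fc(flat)
print(out.shape)                      # torch.Size([2, 10]) - one prediction per image
```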
    
    Original article: https://www.cnblogs.com/xxxxxxxxx/p/11068466.html