
    Reshaping operations

    Suppose we have the following tensor:

    t = torch.tensor([
        [1,1,1,1],
        [2,2,2,2],
        [3,3,3,3]
    ], dtype=torch.float32)
    

    We have two ways to get the shape:

    > t.size()
    torch.Size([3, 4])
    
    > t.shape
    torch.Size([3, 4])
    

    The rank of a tensor is equal to the length of the tensor's shape.

    > len(t.shape)
    2
    

    We can also deduce the number of elements contained within the tensor.

    > torch.tensor(t.shape).prod()
    tensor(12)
    

    In PyTorch, there is a dedicated function for this:

    > t.numel()
    12
    

    Reshaping a tensor in PyTorch

    > t.reshape([2,6])
    tensor([[1., 1., 1., 1., 2., 2.],
            [2., 2., 3., 3., 3., 3.]])
    
    > t.reshape([3,4])
    tensor([[1., 1., 1., 1.],
            [2., 2., 2., 2.],
            [3., 3., 3., 3.]])
    
    > t.reshape([4,3])
    tensor([[1., 1., 1.],
            [1., 2., 2.],
            [2., 2., 3.],
            [3., 3., 3.]])
    
    > t.reshape(6,2)
    tensor([[1., 1.],
            [1., 1.],
            [2., 2.],
            [2., 2.],
            [3., 3.],
            [3., 3.]])
    
    > t.reshape(12,1)
    tensor([[1.],
            [1.],
            [1.],
            [1.],
            [2.],
            [2.],
            [2.],
            [2.],
            [3.],
            [3.],
            [3.],
            [3.]])
    

    In this example, we increase the rank to 3:

    > t.reshape(2,2,3)
    tensor(
    [
        [
            [1., 1., 1.],
            [1., 2., 2.]
        ],
    
        [
            [2., 2., 3.],
            [3., 3., 3.]
        ]
    ])
    

    Note: PyTorch also has a view() function that produces the same result as reshape(). The difference is that view() always returns a view sharing the underlying data and requires a compatible (contiguous) memory layout, while reshape() returns a view when it can and falls back to copying the data otherwise.
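
    A rough sketch of the difference using the same t as above; the transpose is only there to produce a non-contiguous tensor:

    > t.view(2, 6)                 # works: same result as reshape
    tensor([[1., 1., 1., 1., 2., 2.],
            [2., 2., 3., 3., 3., 3.]])

    > t.t().view(2, 6)             # t.t() is non-contiguous, so view() raises a RuntimeError
    > t.t().reshape(2, 6)          # reshape() falls back to a copy and succeeds
    tensor([[1., 2., 3., 1., 2., 3.],
            [1., 2., 3., 1., 2., 3.]])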

    Changing shape by squeezing and unsqueezing

    These functions allow us to expand or shrink the rank (number of dimensions) of our tensor.

    > print(t.reshape([1,12]))
    > print(t.reshape([1,12]).shape)
    tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
    torch.Size([1, 12])
    
    > print(t.reshape([1,12]).squeeze())
    > print(t.reshape([1,12]).squeeze().shape)
    tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
    torch.Size([12])
    
    > print(t.reshape([1,12]).squeeze().unsqueeze(dim=0))
    > print(t.reshape([1,12]).squeeze().unsqueeze(dim=0).shape)
    tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]])
    torch.Size([1, 12])
    
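    Note that squeeze() only removes axes of length 1 (other axes are left alone), and both functions accept a dim argument to target a single axis:

    > t.squeeze().shape            # t has no axes of length 1, so nothing changes
    torch.Size([3, 4])

    > t.unsqueeze(dim=1).shape     # insert a new axis at position 1
    torch.Size([3, 1, 4])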

    Let’s look at a common use case for squeezing a tensor by building a flatten function.

    Flatten a tensor

    Flattening a tensor means removing all of its dimensions except for one.

    A flatten operation reshapes the tensor so that it has a single axis whose length is equal to the number of elements contained in the tensor. The result is the same thing as a 1d array of those elements.

    def flatten(t):
        t = t.reshape(1, -1)  # one row containing every element
        t = t.squeeze()       # remove the leading axis of length 1
        return t
    
    > t = torch.ones(4, 3)
    > t
    tensor([[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])
    
    > flatten(t)
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
    

    We'll see that flatten operations are required when passing an output tensor from a convolutional layer to a linear layer.
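
    As a rough sketch of that hand-off (the layer sizes here are made up for illustration, not taken from the text), the convolutional output keeps its batch axis and is flattened per sample, using flatten(start_dim=1), which is covered below:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)  # example layer
    fc = nn.Linear(in_features=6 * 24 * 24, out_features=10)        # 24 = 28 - 5 + 1

    batch = torch.randn(2, 1, 28, 28)   # 2 grayscale 28x28 images
    out = conv(batch)                   # shape: [2, 6, 24, 24]
    out = out.flatten(start_dim=1)      # shape: [2, 3456], batch axis kept
    out = fc(out)                       # shape: [2, 10]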

    In these examples, we have flattened the entire tensor; however, it is possible to flatten only specific parts of a tensor. For example, suppose we have a tensor of shape [2,1,28,28] for a CNN. This means that we have a batch of 2 grayscale images with height and width dimensions of 28 x 28, respectively.

    Here, we can flatten just the two images to get the shape [2,1,784]. We could also squeeze out the channel axis to get the shape [2,784]. Both options are shown below.
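
    A quick check of these two options on a dummy batch of the stated shape (the random values are just placeholders):

    > batch = torch.rand(2, 1, 28, 28)
    > batch.flatten(start_dim=2).shape                   # flatten height and width only
    torch.Size([2, 1, 784])

    > batch.squeeze(dim=1).flatten(start_dim=1).shape    # drop the channel axis, then flatten each image
    torch.Size([2, 784])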

    Concatenating tensors

    We combine tensors using the cat() function:

    > t1 = torch.tensor([
        [1,2],
        [3,4]
    ])
    > t2 = torch.tensor([
        [5,6],
        [7,8]
    ])
    

    We can combine t1 and t2 row-wise (axis-0) in the following way:

    > torch.cat((t1, t2), dim=0)
    tensor([[1, 2],
            [3, 4],
            [5, 6],
            [7, 8]])
    

    We can combine them column-wise (axis-1) like this:

    > torch.cat((t1, t2), dim=1)
    tensor([[1, 2, 5, 6],
            [3, 4, 7, 8]])
    
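    Note that cat() grows an existing axis rather than adding a new one, which shows up in the shapes:

    > torch.cat((t1, t2), dim=0).shape
    torch.Size([4, 2])

    > torch.cat((t1, t2), dim=1).shape
    torch.Size([2, 4])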

    Flatten operation for a batch of image inputs to a CNN

    Flattening specific axes of a tensor

    We know that the tensor inputs to a convolutional neural network typically have 4 axes, one for batch size, one for color channels, and one each for height and width.

    (Batch Size, Channels, Height, Width)

    To start, suppose we have the following three tensors.

    t1 = torch.tensor([
        [1,1,1,1],
        [1,1,1,1],
        [1,1,1,1],
        [1,1,1,1]
    ])
    
    t2 = torch.tensor([
        [2,2,2,2],
        [2,2,2,2],
        [2,2,2,2],
        [2,2,2,2]
    ])
    
    t3 = torch.tensor([
        [3,3,3,3],
        [3,3,3,3],
        [3,3,3,3],
        [3,3,3,3]
    ])
    

    Remember, batches are represented using a single tensor, so we'll need to combine these three tensors into a single larger tensor that has three axes instead of two.

    > t = torch.stack((t1, t2, t3))
    > t.shape
    
    torch.Size([3, 4, 4])
    

    Here, we used the stack() function to concatenate our sequence of three tensors along a new axis.

    > t
    tensor([[[1, 1, 1, 1],
             [1, 1, 1, 1],
             [1, 1, 1, 1],
             [1, 1, 1, 1]],
    
            [[2, 2, 2, 2],
             [2, 2, 2, 2],
             [2, 2, 2, 2],
             [2, 2, 2, 2]],
    
            [[3, 3, 3, 3],
             [3, 3, 3, 3],
             [3, 3, 3, 3],
             [3, 3, 3, 3]]])
    
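    For reference, stack() is equivalent to unsqueezing each tensor along a new axis and then concatenating, so the following produces the same result:

    > torch.cat((t1.unsqueeze(0), t2.unsqueeze(0), t3.unsqueeze(0)), dim=0).shape
    torch.Size([3, 4, 4])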

    All we need to do now to get this tensor into a form that a CNN expects is to add an axis for the color channels. We basically have an implicit single color channel for each of these image tensors, so in practice these would be grayscale images.

    > t = t.reshape(3,1,4,4)
    > t
    tensor(
    [
        [
            [
                [1, 1, 1, 1],
                [1, 1, 1, 1],
                [1, 1, 1, 1],
                [1, 1, 1, 1]
            ]
        ],
        [
            [
                [2, 2, 2, 2],
                [2, 2, 2, 2],
                [2, 2, 2, 2],
                [2, 2, 2, 2]
            ]
        ],
        [
            [
                [3, 3, 3, 3],
                [3, 3, 3, 3],
                [3, 3, 3, 3],
                [3, 3, 3, 3]
            ]
        ]
    ])
    
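    The same channel axis could also be added with unsqueeze() instead of spelling out the full shape:

    > torch.stack((t1, t2, t3)).unsqueeze(dim=1).shape
    torch.Size([3, 1, 4, 4])
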
    Flattening the tensor batch

    Here are some alternative implementations of the flatten() function.

    > t.reshape(1,-1)[0]
    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

    > t.reshape(-1)
    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

    > t.view(t.numel())
    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

    > t.flatten()
    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
    

    This flattened batch won’t work well inside our CNN because we need individual predictions for each image within our batch tensor, and now we have a flattened mess.

    The solution here is to flatten each image while still maintaining the batch axis. This means we want to flatten only part of the tensor: the color channel axis together with the height and width axes.

    > t.flatten(start_dim=1).shape
    torch.Size([3, 16])
    
    > t.flatten(start_dim=1)
    tensor(
    [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
    ]
    )
    
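    For completeness, the same per-image flattening can also be written with reshape() by keeping the batch axis and letting PyTorch infer the remaining dimension:

    > t.reshape(t.shape[0], -1).shape
    torch.Size([3, 16])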