zoukankan      html  css  js  c++  java
  • 工作小结三

    torch.max()输入两个tensor

    RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

    最近看源代码时候没看懂骚操作

    def find_intersection(set_1, set_2):
        """
        Find the intersection of every box combination between two sets of boxes that are in boundary coordinates.
    
        :param set_1: set 1, a tensor of dimensions (n1, 4)
        :param set_2: set 2, a tensor of dimensions (n2, 4)
        :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
        """
    
        # PyTorch auto-broadcasts singleton dimensions
        lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
        upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
        intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
        return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)
    
    

    那里说求交集应该是两个边界X距离--两个框的宽度乘以两个边界Y距离--两个框的宽度即可

    原来问题出在torch.max()上,简单的用法这里不再赘述,仅仅看最后一个用法,pytorch官方也是一笔带过

    torch.max(input, other, out=None) → Tensor
    Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise maximum is taken.
    
    The shapes of input and other don’t need to match, but they must be broadcastable.
    
    	ext{out}_i = max(	ext{tensor}_i, 	ext{other}_i)
    out_i=max( tensor_i,other_i )
    NOTE
    
    When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.
    
    Parameters
    input (Tensor) – the input tensor.
    
    other (Tensor) – the second input tensor
    
    out (Tensor, optional) – the output tensor.
    
    Example:
    
    >>> a = torch.randn(4)
    >>> a
    tensor([ 0.2942, -0.7416,  0.2653, -0.1584])
    >>> b = torch.randn(4)
    >>> b
    tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
    >>> torch.max(a, b)
    tensor([ 0.8722, -0.7416,  0.2653, -0.1584])
    

    正常如果如初两个shape相同的tensor,直接按元素比较即可

    如果两个不同的tensor上面官方没有说明:

    这里举个例子:输入aaa=2 * 2,bbb=2 * 3

    aaa = torch.randn(2,2)
    bbb = torch.randn(3,2)
    ccc = torch.max(aaa,bbb)
    RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
    

    出现以上的错误,这里先进行分析一下:

    2 * 2 3 * 2无法直接进行比较,按照pytorch官方的说法逐元素比较,那么输出也就应该是2 * 3 * 2,我们进一步进行测试:

    aaa = torch.randn(1,2)
    bbb = torch.randn(3,2)
    ccc = torch.max(aaa,bbb)
    tensor([[1.0350, 0.2532],
            [0.2203, 0.2532],
            [0.2912, 0.2532]])
    

    直接可以输出,不会报错

    原来pytorch的原则是这样的:维度不同只能比较一维的数据

    那么我们可以进一步测试,将输入的2 * 23 * 2转换成1 * 2 * 23 * 1 * 2

    aaa = torch.randn(2,2).unsqueeze(1)
    bbb = torch.randn(3,2).unsqueeze(0)
    ccc = torch.max(aaa,bbb)
    RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
    

    好了,问题完美解决!有时间去看一下源代码怎么实现的,咋不智能。。。。

  • 相关阅读:
    CodeForces 710CMagic Odd Square(经典-奇数个奇数&偶数个偶数)
    CodeForces 710A King Moves(水题-越界问题)
    CodeForces 701C They Are Everywhere (滑动窗口)
    CodeForces 701B Cells Not Under Attack
    [补档]happiness
    [补档]王者之剑
    [补档]士兵占领
    [补档]搭配飞行员
    [补档]暑假集训D6总结
    [补档][Lydsy2017年4月月赛]抵制克苏恩
  • 原文地址:https://www.cnblogs.com/wjy-lulu/p/11878195.html
Copyright © 2011-2022 走看看