zoukankan      html  css  js  c++  java
  • Pytorch 计算两个张量的欧式距离

    1.Pytorch计算公式

    a,b为两个张量,且a.size=(B,N,3),b.size()=(B,M,3),计算a中各点到b中各点的距离,返回距离张量c,c.size()=(B,N,M)。不考虑Batch时,可以将理解:c的第i行j列的值表示a中第i个点到b中第j个点的距离。

    import torch
    
    def EuclideanDistance(t1,t2):
        dim=len(t1.size())
        if dim==2:
            N,C=t1.size()
            M,_=t2.size()
            dist = -2 * torch.matmul(t1, t2.permute(1, 0))
            dist += torch.sum(t1 ** 2, -1).view(N, 1)
            dist += torch.sum(t2 ** 2, -1).view(1, M)
            dist=torch.sqrt(dist)
            return dist
        elif dim==3:
            B,N,_=t1.size()
            _,M,_=t2.size()
            dist = -2 * torch.matmul(t1, t2.permute(0, 2, 1))
            dist += torch.sum(t1 ** 2, -1).view(B, N, 1)
            dist += torch.sum(t2 ** 2, -1).view(B, 1, M)
            dist=torch.sqrt(dist)
            return dist
        else:
            print('error...')
    
    print(f'dimensional 2.......')
    a=torch.Tensor([[0,0],[1,1]])
    b=torch.Tensor([[1,0],[3,4]])
    print(f'size of a:{a.size()}\tsize of b:{b.size()}')
    print(f'distance of point a and b is: {EuclideanDistance(a,b)}')
    print(f'\ndimensional 3.......')
    a=torch.unsqueeze(a,dim=0)
    b=torch.unsqueeze(b,dim=0)
    print(f'size of a:{a.size()}\tsize of b:{b.size()}')
    print(f'distance of point a and b is: {EuclideanDistance(a,b)}')

    2.代码理解

    2.1定义待计算张量

    现有张量a,b如下:

    2.2距离公式

    有距离公式如下:

    2.3分步计算

    (1)计算:

    d1=-2 * torch.matmul(a, b.permute(0, 2, 1))

    (1)结果如下:

    (2)计算:

    d2=torch.sum(a** 2, -1)
    d3=torch.sum(b** 2, -1)

    (2)结果如下:

    当前有:d1.size=(B,N,M),d2.size()=(B,N,1),d3.size()=(B,M,1)

    可以看到d1中的i行中保持不变的部分为a中的第i个点,d1中第j列中不变的部分对应b中的j行。因此,只需在d1的行上加上一个d2的对应行,列上加d3的对应行即可。

    (3)相加:

    d=d1+d2.view(B,N,1)+d3.view(B,1,M)

    (3)结果如下:

     (4)开根

    d=torch.sqrt(d)
  • 相关阅读:
    POJ 3660 Cow Contest (floyd求联通关系)
    POJ 3660 Cow Contest (最短路dijkstra)
    POJ 1860 Currency Exchange (bellman-ford判负环)
    POJ 3268 Silver Cow Party (最短路dijkstra)
    POJ 1679 The Unique MST (最小生成树)
    POJ 3026 Borg Maze (最小生成树)
    HDU 4891 The Great Pan (模拟)
    HDU 4950 Monster (水题)
    URAL 2040 Palindromes and Super Abilities 2 (回文自动机)
    URAL 2037 Richness of binary words (回文子串,找规律)
  • 原文地址:https://www.cnblogs.com/waterbbro/p/15580506.html
Copyright © 2011-2022 走看看