zoukankan      html  css  js  c++  java
  • Pytorch 计算两个张量的欧式距离

    1.Pytorch计算公式

    a,b为两个张量,且a.size=(B,N,3),b.size()=(B,M,3),计算a中各点到b中各点的距离,返回距离张量c,c.size()=(B,N,M)。不考虑Batch时,可以将理解:c的第i行j列的值表示a中第i个点到b中第j个点的距离。

    import torch
    
    def EuclideanDistance(t1,t2):
        dim=len(t1.size())
        if dim==2:
            N,C=t1.size()
            M,_=t2.size()
            dist = -2 * torch.matmul(t1, t2.permute(1, 0))
            dist += torch.sum(t1 ** 2, -1).view(N, 1)
            dist += torch.sum(t2 ** 2, -1).view(1, M)
            dist=torch.sqrt(dist)
            return dist
        elif dim==3:
            B,N,_=t1.size()
            _,M,_=t2.size()
            dist = -2 * torch.matmul(t1, t2.permute(0, 2, 1))
            dist += torch.sum(t1 ** 2, -1).view(B, N, 1)
            dist += torch.sum(t2 ** 2, -1).view(B, 1, M)
            dist=torch.sqrt(dist)
            return dist
        else:
            print('error...')
    
    print(f'dimensional 2.......')
    a=torch.Tensor([[0,0],[1,1]])
    b=torch.Tensor([[1,0],[3,4]])
    print(f'size of a:{a.size()}\tsize of b:{b.size()}')
    print(f'distance of point a and b is: {EuclideanDistance(a,b)}')
    print(f'\ndimensional 3.......')
    a=torch.unsqueeze(a,dim=0)
    b=torch.unsqueeze(b,dim=0)
    print(f'size of a:{a.size()}\tsize of b:{b.size()}')
    print(f'distance of point a and b is: {EuclideanDistance(a,b)}')

    2.代码理解

    2.1定义待计算张量

    现有张量a,b如下:

    2.2距离公式

    有距离公式如下:

    2.3分步计算

    (1)计算:

    d1=-2 * torch.matmul(a, b.permute(0, 2, 1))

    (1)结果如下:

    (2)计算:

    d2=torch.sum(a** 2, -1)
    d3=torch.sum(b** 2, -1)

    (2)结果如下:

    当前有:d1.size=(B,N,M),d2.size()=(B,N,1),d3.size()=(B,M,1)

    可以看到d1中的i行中保持不变的部分为a中的第i个点,d1中第j列中不变的部分对应b中的j行。因此,只需在d1的行上加上一个d2的对应行,列上加d3的对应行即可。

    (3)相加:

    d=d1+d2.view(B,N,1)+d3.view(B,1,M)

    (3)结果如下:

     (4)开根

    d=torch.sqrt(d)
  • 相关阅读:
    axios 修改头部请求数据格式的方法
    基于VUE的可以滚动的横向时间轴
    25.客户端多线程分组模拟高频并发数据
    24.原子操作
    23.线程锁的使用
    22.线程自解锁
    21.多线程-锁与临界区域
    20.多线程-基本代码
    19.添加高精度计时器测量处理能力
    18.windows使用select突破64个socket
  • 原文地址:https://www.cnblogs.com/waterbbro/p/15580506.html
Copyright © 2011-2022 走看看