zoukankan      html  css  js  c++  java
  • einsum爱因斯坦求和

    最近因为vision transformer里的pytorch代码,看到了torch.einsum(np.einsum同理)这个操作,简直是神了;

    比如

    t = torch.randn(2,4,3)
    q, k, v = tuple(rearrange(t, 'b t (d k) -> k b t d ', k=3))
    print(q,'
    ',k)
    
    >>>
    tensor([[[-0.9011],
             [-0.2627],
             [ 0.4202],
             [-0.3396]],
    
            [[ 0.0530],
             [ 0.5980],
             [ 0.1464],
             [ 0.7939]]]) 
     tensor([[[-1.0567],
             [ 0.0425],
             [-0.2160],
             [-2.2235]],
    
            [[ 0.3932],
             [-0.5011],
             [ 0.0748],
             [-1.3025]]])

    可以看到这里生成了transformer里的q,k,v,维度是(2,4,1),维度含义分别是 (batch_size, token,dim),然后要做一个q*k^T的向量外积

    scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k)
    
    scaled_dot_prod
    >>>tensor([[[ 0.9523, -0.0383,  0.1947,  2.0037],
             [ 0.2776, -0.0112,  0.0568,  0.5842],
             [-0.4440,  0.0179, -0.0908, -0.9342],
             [ 0.3588, -0.0144,  0.0734,  0.7551]],
    
            [[ 0.0208, -0.0265,  0.0040, -0.0690],
             [ 0.2351, -0.2996,  0.0447, -0.7789],
             [ 0.0575, -0.0733,  0.0109, -0.1906],
             [ 0.3122, -0.3978,  0.0594, -1.0341]]])

    注意,这里的q和k都是同一维度,不用像原来做矩阵乘法那样要维度对应,而是可以直接指定维度去对应地乘;

    因此,这里把k换到(2,1,4)的维度然后去和q乘,也是可以的,例如:

    k_ = rearrange(k,'b t d -> b d t')
    k_
    a_scaled_dot_prod = torch.einsum('b i d , b d j -> b i j', q, k_)
    
    a_scaled_dot_prod
    
    >>>
    tensor([[[ 0.9523, -0.0383,  0.1947,  2.0037],
             [ 0.2776, -0.0112,  0.0568,  0.5842],
             [-0.4440,  0.0179, -0.0908, -0.9342],
             [ 0.3588, -0.0144,  0.0734,  0.7551]],
    
            [[ 0.0208, -0.0265,  0.0040, -0.0690],
             [ 0.2351, -0.2996,  0.0447, -0.7789],
             [ 0.0575, -0.0733,  0.0109, -0.1906],
             [ 0.3122, -0.3978,  0.0594, -1.0341]]])

     参考:https://zhuanlan.zhihu.com/p/74462893

    人生苦短,何不用python
  • 相关阅读:
    队列与栈的综合实现
    枚举属性和不可枚举属性
    Ajax状态值及状态码
    jquery版滑块导航栏
    js版面向对象图片放大镜
    jq封装淘宝图片轮播插件
    前端必备的js知识点(转载)
    如何有效地解决ie7,IE8不支持document.getElmentsByClassName的问题
    mysql的基本命令行操作
    jquery版楼层滚动特效
  • 原文地址:https://www.cnblogs.com/yqpy/p/14479749.html
Copyright © 2011-2022 走看看