zoukankan      html  css  js  c++  java
  • pytorch中参数dim的含义(正负,零,不传)

    总结:

    torch.function(x, dim)

    1.if 不传: 依照默认参数决定

    2.if dim >=0 and dim <= x.dim()-1: 0是沿最粗数据粒度的方向进行操作,x.dim()-1是按最细粒度的方向。

    3.if dim <0: dim的最小取值(此按照不同function而定)到最大取值(-1)之间。与情况2正好相反,最大的取值(-1)代表按最细粒度的方向,最小的取值按最粗粒度的方向。

    实验代码:(使用torch.max(x, dim)为例子)

    1.dim=2

    m
    Out[77]:
    tensor([[1, 2, 3],
            [4, 5, 6]])

    torch.max(m,)
    Out[85]: tensor(6)

    不传:默认参数的设定是对整个传入的数据进行操作


    torch.max(m, dim=0)
    Out[79]:
    torch.return_types.max(
    values=tensor([4, 5, 6]),
    indices=tensor([1, 1, 1]))

    此处最粗粒度是两行之间[1, 2, 3]->[4, 5, 6]的方向,也就是常说是纵向进行操作。


    torch.max(m, dim=1)
    Out[78]:
    torch.return_types.max(
    values=tensor([3, 6]),
    indices=tensor([2, 2]))

    此处最细粒度是一行之内[1, 2, 3]的方向,也就是常说是横向进行操作。

    torch.max(m, dim=2)
    Traceback (most recent call last):
      File "/home/xutianfan/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3418, in run_code
        exec(code_obj, self.user_global_ns, self.user_ns)
      File "<ipython-input-84-ce6440fe62e4>", line 1, in <module>
        torch.max(m, dim=2)
    IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

    torch.max(m, dim=-1)
    Out[86]:
    torch.return_types.max(
    values=tensor([3, 6]),
    indices=tensor([2, 2]))

    -1+2=1,同torch.max(m, dim=1)结果。

    torch.max(m, dim=-2)
    Out[87]:
    torch.return_types.max(
    values=tensor([4, 5, 6]),
    indices=tensor([1, 1, 1]))

    2.dim=3(tensor)

    t1
    Out[89]:
    tensor([[[0, 1, 2, 3],
             [1, 2, 3, 4]],
            [[2, 3, 4, 5],
             [4, 5, 6, 7]],
            [[5, 6, 7, 8],
             [6, 7, 8, 9]]])

    torch.max(t1)
    Out[94]: tensor(9)

    torch.max(t1, dim=0)
    Out[91]:
    torch.return_types.max(
    values=tensor([[5, 6, 7, 8],
            [6, 7, 8, 9]]),
    indices=tensor([[2, 2, 2, 2],
            [2, 2, 2, 2]]))

    最粗粒度是在各个矩阵之间的方向,所以对各个矩阵的每个位置分别取最大。

    torch.max(t1, dim=1)
    Out[92]:
    torch.return_types.max(
    values=tensor([[1, 2, 3, 4],
            [4, 5, 6, 7],
            [6, 7, 8, 9]]),
    indices=tensor([[1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1]]))

    其次粗的粒度是矩阵中各行之间的方向

    torch.max(t1, dim=2)
    Out[93]:
    torch.return_types.max(
    values=tensor([[3, 4],
            [5, 7],
            [8, 9]]),
    indices=tensor([[3, 3],
            [3, 3],
            [3, 3]]))
    最细粒度是各行之内的方向。所以取出了各行中最大的元素。

    torch.max(t1, dim=-1)
    Out[97]:
    torch.return_types.max(
    values=tensor([[3, 4],
            [5, 7],
            [8, 9]]),
    indices=tensor([[3, 3],
            [3, 3],
            [3, 3]]))

    虽然我们这里只使用了max函数,但是这对于torch中其他函数(例如softmax)也有效。

    可以有这种写法:mean = x.mean(-1, keepdim=True)

    这样无论是对于2维还是3维的输入,都自动dim=input.dim()-1,也就是从最细粒度取平均。

  • 相关阅读:
    iTerm2使用技巧
    我的mac下有关php扩展的安装
    xmlhttprequest 1.0和2.0的区别,from qq前端哥
    PHP错误日志记录:display_errors与log_errors的区别
    目前php连接mysql的主要方式
    闭包介绍汇总
    接口设计知识总结
    git命令——从GitHub clone XXX分支,本地创建新分支push到远程仓库
    Spring错误——Junit测试——java.net.BindException: Address already in use: bind
    Java.util.Random生成随机数
  • 原文地址:https://www.cnblogs.com/gagaein/p/14366274.html
Copyright © 2011-2022 走看看