zoukankan      html  css  js  c++  java
  • pytorch-tensor处理速查表(cat stack squeeze unsqueeze permute等)

    1 torch.cat

    torch.cat((A, B), dim)
    

    将两个tensor在指定维度进行拼接

        A = torch.zeros(2,3)
        B = torch.zeros(2,3)
        C = torch.cat((A,B), 0) ## shape [4,3]
        D = torch.cat((A,B), 1) ## shape [2,6]
    

    2 torch.stack

    torch.stack((A, B), dim)
    

    增加新的维度进行堆叠

    A = torch.zeros(1,3)
    B = torch.zeros(1,3)
    C = torch.stack((A,B), 0)  ## [2, 1, 3]
    D = torch.stack((A,B), 1)  ## [1, 2, 3]
    E = torch.stack((A,B), 2)  ## [1, 3, 2]
    

    3 torch.permute

    A = A.permute(0, 2, 3, 1)
    

    调整tensor的维度顺序,相当于更灵活的transpose

    A = torch.zeros(32, 3, 18, 18)  ## [32, 3, 18, 18]
    B = A.permute(0, 2, 3, 1)          ##[32, 18, 18, 3]
    

    4 tensor.contiguous
    view只能用在contiguous的tensor上。如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。
    eg:

    v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
    

    5 tensor.squeeze

    A = A.squeeze(dim)
    

    去掉tensor的维度为1的维度,该维度可以通过参数dim指定,也可以不加参数,默认找到维度为1的维度然后去掉

    A = torch.zeros(1, 18, 18)  ## [1, 18, 18]
    B = A.squeeze(0)               ## [18, 18]
    

    6 tensor.unsqueeze

    A = A.unsqueee(dim)
    

    在tensor中增加一个新的指定维度,新维度放在指定位置 原来维度序列向两边移动

    A = torch.zeros(2, 3, 4)   ## [2, 3, 4]
    B = A.unsqueeze(0)    ## [1, 2, 3, 4]
    C = A.unsqueeze(1)    ## [2, 1, 3, 4]      
    D = A.unsqueeze(2)    ## [2, 3, 1, 4]
    E = A.unsqueeze(3)    ## [2, 3, 4, 1]
    

    7 tensor.expand

    A = A.expand()
    

    在指定维度上扩展数据, 该指定维度长度为1,否则报错。(此时扩展仅是创建新的视图,并不进行数据复制)

    A = torch.zeros(2, 3, 1) ## [2, 3, 1]
    B = A.expand(2, 3, 3)   ## [2, 3,  3]
    

    8 tensor.clone()
    clone() 得到的tensor不仅拷贝了原始的value,而且会计算梯度传播信息

    b = a.clone()
    

    9 tensor.copy_(src_tensor)
    只拷贝src_tensor的数据到dst_tensor上,并返回self

    a = torch.ones([3,4])
    b = torch.zeros([3,4])
    b.copy_(a)
    

    10 生成特定尺度、特定数值的tensor

    a = torch.Tensor(3,5).fill_(0)
    a = torch.full((3,5), 0, dtype=torch.IntTensor)
    
    如果有一天我们淹没在茫茫人海中庸碌一生,那一定是我们没有努力活得丰盛
  • 相关阅读:
    Linux下PHP安装配置MongoDB数据库连接扩展
    Linux下安装配置MongoDB数据库
    解决VMWARE 虚拟机安装64位系统“此主机支持 Intel VT-x,但 Intel VT-x 处于禁用状态
    nginx配置多域名
    nginx File not found 错误
    RunLoop与NSTimer的经典面试题
    子线程上的RunLoop运行循环
    主线程上的RunLoop运行循环
    RunLoop运行循环/消息循环
    自动释放池和运行/消息循环
  • 原文地址:https://www.cnblogs.com/yeran/p/11113926.html
Copyright © 2011-2022 走看看