zoukankan      html  css  js  c++  java
  • torch Tensor学习:切片操作

    torch Tensor学习:切片操作

    一直使用的是matlab处理矩阵,想从matlab转到lua+torch上,然而在matrix处理上遇到了好多类型不匹配问题。所以这里主要总结一下torch/Tensor中切片操作方法以及其参数类型,以备查询。

    已知有矩阵M

    M=torch.range(1,20):resize(4,5)
    
    th> M
      1   2   3   4   5
      6   7   8   9  10
     11  12  13  14  15
     16  17  18  19  20
    [torch.DoubleTensor of size 4x5]
    

    这里需要注意的是 resize() 和reshape()的区别, M:reshape(5,4)并不改变M,但是M:resize(5,4)改变M的size

    1. 选取M的第2,3列
    N1=M:narrow(2,2,2)
    -- [self] narrow (dim,index,size)   --> dim表示待选取的维度,index 是待选取连续列的起始数值, size是选取的列数,那么总体就是在第二维上从第2列选取连续的2列
    
    N2=M:sub(1,-1,2,3)
    -- [Tensor] sub(dim1s,dim1e ...[,dim4s [,dim4e]])
    --dim1s,dim1e 分别对应第一维上的起始index和终止index,-1表示该维度倒数第一个...以此类推
    
    N3=M[{{},{2,3}}]
    

    另外
    [Tensor] select(dim,index) :在第dim维选取第index"行",这只能取一个切片

    需要注意的几点:
    A. sub, narrow, select, [] 函数都是在原始的数据上进行操作的,也就是说获得的Tensor都仅仅是一个视图而已,改变其中一个另一个也会变化
    B. 因为A的原因,这几个函数执行的速度非常快(对于这点深有体会,torch中Tensor操作速度蛮快的,但是涉及到内存分配,速度有点慢)
    C. sub,narrow,select 都是选取Tensor一块数据,而不能选取特定的值,比如一次性选取上述矩阵M的第1行第2列,第2行第4列 和第4行第1列的值
    C. [Tensor][{dim1,dim2}] or [{{dim1s,dim1e},{dim2s,dim2e}}] 的输入量还可以是ByteTensor类型,由0,1元素组成的掩码矩阵,比如

    M[torch.lt(M,8)] =0   --令M中小于8的元素为0
    
    th> M
      0   0   0   0   0
      0   0   8   9  10
     11  12  13  14  15
     16  17  18  19  20
    [torch.DoubleTensor of size 4x5]
    

    注: 逻辑操作:lt, gt, le, eq, ge,ne返回的都是ByteTensor类型的掩码Tensor

    1. 针对于上边提到的注意点C,如果想选取第dim维上的某几个不连续的“行”
      例如 选取M的第2,5,3,1列构成新的矩阵
    N1=M:index(2,torch.LongTensor({2,5,3,1}))
    -- [Tensor] index(dim,indices) 在矩阵M中选取dim维上索引indices中对应的"行",indices的类型要求为LongTensor
    -- 返回Tensor的number of dimensions 和原始Tensor相同,返回的是新分配的内存
    

    注: [] 操作是一系列narrow,select,sub的组合,其并不涉及新内存,而index则涉及到新内存的分配
    index的相关函数
    a. indexCopy

    -- [Tensor] indexCopy(dim,indices,tensor) -- 将tensor中的元素拷贝到原tensor对应indices上去,tensor和带存储的大小应该严格一致
    N=torch.Tensor(4,2):fill(-1)
    M:indexCopy(2,torch.LongTensor{4,1},N)   -- 这个返回的是N,但M变化了
    
    th> M
     -1   2   3  -1   5
     -1   7   8  -1  10
     -1  12  13  -1  15
     -1  17  18  -1  20
    [torch.DoubleTensor of size 4x5]
    

    b. indexAdd

    -- [Tensor] indexAdd(dim,indices,tensor) 和indexCopy类似,只是在原tensor的indices对应的位置加上tensor
    M:indexAdd(2,torch.LongTensor{1,3},-N) --返回的是-N
    
    th> M
      0   2   4  -1   5
      0   7   9  -1  10
      0  12  14  -1  15
      0  17  19  -1  20
    [torch.DoubleTensor of size 4x5]
    

    c. indexFill
    [Tensor] indexFill(dim, indices,var) -- 和indexCopy相同,只是使用var去填充indices对应的元素,在原Tensor内存上改变

    总结:
    A. index相关的函数中只有index是重新开辟的内存,而indexCopy,indexFill,indexAdd均是在原内存上操作
    B. index相关函数仅仅能在某一个特定维度dim上,相对自由的选取indices,而不能同时操作多个维度
    C. index相关函数中 indices参数的类型为LongTensor!!!尤其要注意

    1. tensor中元素的自由选取和赋值
      a. gather
      [Tensor] gather (dim,indices) -- 首先这个函数需要重新分配内存
      。-- 该函数的功能主要是沿着dim维度,在每一个row上按照indices选取数值,indices为LongTensor类型
      看下面官方文档的图示更好理解

    enter description here

    1486728865883.jpg

    左图是: result=src:gatter(1,index), index=torch.LongTensor({{1,2,3},{2,3,1}})
    其输出result的大小和index相同,result[{1,{}}]为src每一列上的index[{{1},{}}]对应的元素,为什么是列呢,因为dim=1,决定了沿着row数,也就是列了。 后面的类似

    注意:因为dim=1,所以index的第二维长度应该和src的dim=2长度相同

    b. scatter
    [Tensor] scatter(dim,indices,src|var)
    这个函数和gatter是一组的,gather是取元素,scatter是元素赋值,其indices查找方式相同。待赋值可以使Tensor src也可以是标量var
    同样有新的内存分配发生

    c. maskedSelect
    [Tensor] maskedSelect(mask)
    。-- mask 是ByteTensor类型的掩码矩阵或者向量,元素为0或1. mask并不要求size和src相同,但元素个数必须相同。
    。--返回的是mask中元素1对应的src中元素,长度和mask中1的个数相同,元素类型和src类型相同,ndim=1

    d. maskedCopy
    [Tensor] maskedCopy(mask,tensor)
    。--和maskedSelect的关系,就和index与indexCopy的关系,对掩码确定的元素进行赋值

    e. maskedFill
    [Tensor] maskedFill(mask,val)
    。 --类比index和indexFill

    1. 总结
      A: mask都是ByteTensor类型,indices都是LongTensor类型
      B: view操作的有 narrow, sub,select 和 [ ]
      C: 重新分配内存的有 index,gatter,maskedSelect
      D: 不牵涉到内存的重新分配能够大大提升程序的效率
  • 相关阅读:
    java时间戳转换日期 和 日期转换时间戳
    通过blob文件导出下载成Excel文件
    三元表达式进化
    Vue切换组件实现返回后不重置数据,保留历史设置操作
    vue 下载文件
    ide打断点,跑到某一行代码,再执行的方法
    Java操作终端的方法
    前端下载本地文件的方法
    java 读取本地json文件
    js 时间戳转换
  • 原文地址:https://www.cnblogs.com/YiXiaoZhou/p/6387769.html
Copyright © 2011-2022 走看看