zoukankan      html  css  js  c++  java
  • Tensor索引操作

     
    1. #Tensor索引操作  
    2.     ''''' 
    3.     Tensor支持与numpy.ndarray类似的索引操作,语法上也类似 
    4.     如无特殊说明,索引出来的结果与原tensor共享内存,即修改一个,另一个会跟着修改 
    5.     '''  
    6.     import torch as t  
    7.       
    8.     a = t.randn(3,4)  
    9.     '''''tensor([[ 0.1986,  0.1809,  1.4662,  0.6693], 
    10.             [-0.8837, -0.0196, -1.0380,  0.2927], 
    11.             [-1.1032, -0.2637, -1.4972,  1.8135]])'''  
    12.     print(a[0])         #第0行  
    13.     '''''tensor([0.1986, 0.1809, 1.4662, 0.6693])'''  
    14.     print(a[:,0])       #第0列  
    15.     '''''tensor([ 0.1986, -0.8837, -1.1032])'''  
    16.     print(a[0][2])      #第0行第2个元素,等价于a[0,2]  
    17.     '''''tensor(1.4662)'''  
    18.     print(a[0][-1])     #第0行最后一个元素  
    19.     '''''tensor(0.6693)'''  
    20.     print(a[:2,0:2])    #前两行,第0,1列  
    21.     '''''tensor([[ 0.1986,  0.1809], 
    22.             [-0.8837, -0.0196]])'''  
    23.       
    24.     print(a[0:1,:2])    #第0行,前两列  
    25.     '''''tensor([[0.1986, 0.1809]])'''  
    26.     print(a[0,:2])      #注意两者的区别,形状不同  
    27.     '''''tensor([0.1986, 0.1809])'''  
    28.       
    29.     print(a>1)  
    30.     '''''tensor([[0, 0, 1, 0], 
    31.             [0, 0, 0, 0], 
    32.             [0, 0, 0, 1]], dtype=torch.uint8)'''  
    33.     print(a[a>1])        #等价于a.masked_select(a>1),选择结果与原tensor不共享内存空间  
    34.     print(a.masked_select(a>1))  
    35.     '''''tensor([1.4662, 1.8135]) 
    36.     tensor([1.4662, 1.8135])'''  
    37.     print(a[t.LongTensor([0,1])])  
    38.     '''''tensor([[ 0.1986,  0.1809,  1.4662,  0.6693], 
    39.             [-0.8837, -0.0196, -1.0380,  0.2927]])'''  
    40.       
    41.     ''''' 
    42.                             常用的选择函数 
    43.     index_select(input,dim,index)   在指定维度dim上选取,列如选择某些列、某些行 
    44.     masked_select(input,mask)       例子如上,a[a>0],使用ByteTensor进行选取 
    45.     non_zero(input)                 非0元素的下标 
    46.     gather(input,dim,index)         根据index,在dim维度上选取数据,输出size与index一样 
    47.     gather是一个比较复杂的操作,对一个二维tensor,输出的每个元素如下: 
    48.         out[i][j] = input[index[i][j]][j]   #dim = 0 
    49.         out[i][j] = input[i][index[i][j]]   #dim = 1 
    50.     '''  
    51.       
    52.     b = t.arange(0,16).view(4,4)  
    53.     '''''tensor([[ 0,  1,  2,  3], 
    54.             [ 4,  5,  6,  7], 
    55.             [ 8,  9, 10, 11], 
    56.             [12, 13, 14, 15]])'''  
    57.     index = t.LongTensor([[0,1,2,3]])  
    58.     print(b.gather(0,index))            #取对角线元素  
    59.     '''''tensor([[ 0,  5, 10, 15]])'''  
    60.       
    61.     index = t.LongTensor([[3,2,1,0]]).t()       #取反对角线上的元素  
    62.     print(b.gather(1,index))  
    63.     '''''tensor([[ 3], 
    64.             [ 6], 
    65.             [ 9], 
    66.             [12]])'''  
    67.       
    68.     index = t.LongTensor([[3,2,1,0]])           #取反对角线的元素,与上面不同  
    69.     print(b.gather(0,index))  
    70.     '''''tensor([[12,  9,  6,  3]])'''  
    71.       
    72.     index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()  
    73.     print(b.gather(1,index))  
    74.     '''''tensor([[ 0,  3], 
    75.             [ 5,  6], 
    76.             [10,  9], 
    77.             [15, 12]])'''  
    78.       
    79.     ''''' 
    80.     与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而 
    81.     scatter_是把取出的数据再放回去,scatter_函数时inplace操作 
    82.     out = input.gather(dim,index) 
    83.     out = Tensor() 
    84.     out.scatter_(dim,index) 
    85.     '''  
    86.       
    87.     x = t.rand(2, 5)  
    88.     print(x)  
    89.     c = t.zeros(3, 5).scatter_(0, t.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)  
    90.     print(c)  
    91. 2018-10-23 20:30:30       
    Monkey
  • 相关阅读:
    从csdn转移到博客园的一篇测试文章
    接口与抽象类的区别
    python网络爬虫进阶之HTTP原理,爬虫的基本原理,Cookies和代理介绍
    python验证码识别(2)极验滑动验证码识别
    VMWare虚拟机应用介绍
    Rpg maker mv角色扮演游戏制作大师简介
    python数据挖掘之数据探索第一篇
    python数据分析&挖掘,机器学习环境配置
    python爬取豆瓣视频信息代码
    python验证码处理(1)
  • 原文地址:https://www.cnblogs.com/monkeyT/p/9839150.html
Copyright © 2011-2022 走看看