zoukankan      html  css  js  c++  java
  • pytorch索引与切片

    @

    index索引

    torch会自动从左向右索引

    例子:

    a = torch.randn(4,3,28,28)
    

    表示类似一个CNN 的图片的输入数据,4表示这个batch一共有4张照片,而3表示图片的通道数为3(RGB),(28,28)表示图片的大小

    基本索引

    索引1:表示第零张图片的shape

    print(a[0].shape)
    #torch.Size([3,28,28])
    

    索引2:第零张图片的第零个通道的size

    print(a[0,0].shape)
    #torch.Size([28,28])
    

    索引3:表示第零张图片的第零个通道的第二行第四列的像素点的值

    print(a[0,0,2,4])
    #tensor(0.8082)
    

    连续选取

    ⭐索引4:连续取两张图片(取第0张以及第一张图片,不包括第二张)

    print(a[:2].shape
    #torch.Size([2,3,28,28])
    #由于是两张图片,所以第一维变为2
    

    ⭐索引5:前两张图片上的第一个通道上的数据(所以通道数变为了1)

    print(a[:2,:1,:,:].shape)
    print(a[:2,:1].shape)
    #torch.Size(2,1,28,28)
    

    ⭐索引6:从后面取(-1表示最后一个,从最后一个取到最后,也就是一个通道)

    print(a[:2,-1:,:,:].shape)
    
    #torch.Size(2,1,28,28)
    

    规则间隔索引

    ⭐索引7:在图片的矩阵进行隔行与隔列索引 0:28:2表示从0到28(不包括28),间隔数为2

    print(a[:,:,0:28:2,0:28:2].shape)
    print(a[:,:,::2,::2].shape)
    #torch.Size([4,3,14,14])
    

    索引总结

    start : end : step

    都取

    x:从x取到最后 :x 从开始取到x x:y从x取到y

    x:y:z从x到y每隔z个点采样一次

    不规则间隔索引

    使用index_select()函数

    第一个参数表示你对哪个维度进行操作;第二个参数是index(必须是tensor类型):对第0张与第2张图片进行操作

    a.index_select(0,torch.tensor([0,2])).shape
    #【2,3,28,28】
    

    同理:选择了两个通道

    a.index_select(1,torch.tensor([1,2])).shape
    #【4,2,28,28】
    

    同理:只取8行

    a.index_select(2,torch.arange(8)).shape
    #【4,2,8,28】
    

    任意多的维度索引

    使用符号:...

    例子:

    a[...].shape
    #[4,3,28,28]
    
    a[0,...].shape
    #[3,28,28]
    
    a[0,1,...].shape
    #[4,28,28]
    
    a[...,2].shape
    #[4,3,28,2]
    

    使用掩码来索引

    函数:.masked_select()会将筛选出来的元素打平(因为无法维护原来的shape)

    x = torch.randn(2,3)
    print(x)
    
    tensor([[-1.3081, -0.5651, -0.9843],
            [ 1.0051, -0.3829,  0.6300]])
    
    mask = x.ge(0.5)#大于等于0.5的元素
    print(mask)
    
    tensor([[False, False, False],
            [ True, False,  True]])
    
    z = torch.masked_select(x,mask)
    print(z)
    
    tensor([1.0051, 0.6300])
    
    

    打平后的索引

    例子:使用take函数:是将输入的tensor打平之后进行index的选择

    src = torch.tensor([[4,3,5],[6,7,8]])
    torch.take(src,torch.tensor([0,2,8]))
    #tensor([4,5,8])
    
  • 相关阅读:
    lnmp mysql高负载优化
    vi查找替换命令详解
    学习资料汇总
    App应用推广
    sed 命令详解
    10个经典的Android开源应用项目
    Java基础学习总结——Java对象的序列化和反序列化
    Java制作证书的工具keytool用法总结
    Linux下安装Tomcat服务器和部署Web应用
    谈谈对Spring IOC的理解
  • 原文地址:https://www.cnblogs.com/Jason66661010/p/13592020.html
Copyright © 2011-2022 走看看