zoukankan      html  css  js  c++  java
  • Pytorch-tensor的分割,属性统计

    1.矩阵的分割

    方法:split(分割长度,所分割的维度)split([分割所占的百分比],所分割的维度)
    a=torch.rand(32,8)
    aa,bb=a.split(16,dim=0)
    print(aa.shape)
    print(bb.shape)
    cc,dd=a.split([20,12],dim=0)
    print(cc.shape)
    print(dd.shape)
    

    输出结果

    torch.Size([16, 8])
    torch.Size([16, 8])
    torch.Size([20, 8])
    torch.Size([12, 8])
    

    2.tensor的属性统计

    min(dim=1):返回第一维的所有最小值,以及下标
    max(dim=1):返回第一维的所有最大值,以及下标
    a=torch.rand(4,3)
    print(a,'
    ')
    print(a.min(dim=1),'
    ')
    print(a.max(dim=1))
    

    输出结果

    tensor([[0.3876, 0.5638, 0.5768],
            [0.7615, 0.9885, 0.9660],
            [0.3622, 0.4334, 0.1226],
            [0.9390, 0.6292, 0.8370]]) 
            
    torch.return_types.min(
    values=tensor([0.3876, 0.7615, 0.1226, 0.6292]),
    indices=tensor([0, 0, 2, 1])) 
    
    torch.return_types.max(
    values=tensor([0.5768, 0.9885, 0.4334, 0.9390]),
    indices=tensor([2, 1, 1, 0]))
    
    
    mean:求平均值
    prod:求累乘
    sum:求累加
    argmin:求最小值下标
    argmax:求最大值下标
    a=torch.rand(1,3)
    print(a)
    print(a.mean())
    print(a.prod())
    print(a.sum())
    print(a.argmin())
    print(a.argmax())
    

    输出结果

    tensor([[0.5366, 0.9145, 0.0606]])
    tensor(0.5039)
    tensor(0.0297)
    tensor(1.5117)
    tensor(2)
    tensor(1)
    

    3.tensor的topk()和kthvalue()

    topk(k,dim=a,largest=):输出维度为1的前k大的值,以及它们的下标。
    kthvalue(k,dim=a):输出维度为a的第k小的值,并输出它的下标。
    a=torch.rand(4,4)
    print(a,'
    ')
    # 输出每一行中2个最大的值,并输出它们的下标
    print(a.topk(2,dim=1),'
    ')
    
    # 输出每一行中3个最小的值,并输出它们的下标
    print(a.topk(3,dim=1,largest=False),'
    ')
    
    # 输出每一行第2小的值,并输出下标
    print(a.kthvalue(2,dim=1))
    

    输出结果

    tensor([[0.7131, 0.8148, 0.8036, 0.4720],
            [0.9135, 0.4639, 0.5114, 0.2277],
            [0.1314, 0.8407, 0.7990, 0.9426],
            [0.6556, 0.7316, 0.9648, 0.9223]]) 
    
    torch.return_types.topk(
    values=tensor([[0.8148, 0.8036],
            [0.9135, 0.5114],
            [0.9426, 0.8407],
            [0.9648, 0.9223]]),
    indices=tensor([[1, 2],
            [0, 2],
            [3, 1],
            [2, 3]])) 
    
    torch.return_types.topk(
    values=tensor([[0.4720, 0.7131, 0.8036],
            [0.2277, 0.4639, 0.5114],
            [0.1314, 0.7990, 0.8407],
            [0.6556, 0.7316, 0.9223]]),
    indices=tensor([[3, 0, 2],
            [3, 1, 2],
            [0, 2, 1],
            [0, 1, 3]])) 
    
    torch.return_types.kthvalue(
    values=tensor([0.7131, 0.4639, 0.7990, 0.7316]),
    indices=tensor([0, 1, 2, 1]))
    
  • 相关阅读:
    我的航拍直升机 控制基站软件的编写历程(2.2)——Qt Creator 版本控制系统
    windows下QT开发环境建立方法
    QT 4.5 windows版本 安装问题 及 Junction 使用
    各种平台下编译qt工程
    华为面试题
    strcpy,strncpy,strlcpy,memcpy
    QT/E 更换字体问题
    Linux设备驱动编程中断处理
    oracle数据库连接池的使用
    Windows下QT的安装
  • 原文地址:https://www.cnblogs.com/52dxer/p/13779473.html
Copyright © 2011-2022 走看看