zoukankan      html  css  js  c++  java
  • pytorch高阶操作

    pytorch高阶操作

    where函数

    torch.where(condition,x,y)

    可能新生成的tensor一部分来自x,一部分来自y,但是是没有规律的

    例子:假设一个tensor表示识别概率,大于0.5表示1,小于0.5表示0

    a = torch.rand(2,2)
    print(a)
    
    tensor([[0.9872, 0.9270],
            [0.6795, 0.0959]])
    
    
    aa = torch.zeros(2,2)
    bb = torch.ones(2,2)
    
    answer = torch.where(a>0.5,aa,bb)
    print(answer)
    
    tensor([[0., 0.],
            [0., 1.]])
    

    gather函数

    实际就是一个查表的函数

    比如像手写数字的识别,【4,10】4张图片,最后识别出每张图片中10个概率最大的index(一般index为几这个数字就是几),但是如果我们的标签不是1~10,而是另外有一张表来对应,不同的index对应不同的标签,这时就可以使用gather函数

    例子:

    prob = torch.rand(4,10)
    
    idx = prob.topk(3,dim=1)
    idx1 = idx[1]
    
    print(idx1)
    
    tensor([[1, 3, 4],
            [2, 0, 3],
            [5, 4, 2],
            [9, 4, 5]])
    
    label = torch.arange(10)+100#为了方面随便初始化的label
    
    print(torch.gather(label.expand(4,10),dim=1,index=idx1.long()))
    
    
    tensor([[101, 103, 104],
            [102, 100, 103],
            [105, 104, 102],
            [109, 104, 105]])
    
  • 相关阅读:
    数据库的基本操作
    这是数据库的知识了
    这就全都是了解的东西啦
    互斥锁
    我只会用threading,我菜
    violet
    网络编程II
    网络编程
    这是网络编程的一小步,却是我的一大步
    莫比乌斯反演(一)从容斥到反演
  • 原文地址:https://www.cnblogs.com/Jason66661010/p/13603199.html
Copyright © 2011-2022 走看看