zoukankan      html  css  js  c++  java
  • pytorch记录

    有两个tensor是A和B

    C = torch.cat( (A,B),0 ) #按维数0拼接(竖着拼) C = torch.cat( (A,B),1 ) #按维数1拼接(横着拼)
     A = torch.ones(2,3)
     B = torch.ones(4,3)
     out=torch.cat((A,B),0)
    tensor([[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])
    
    
    C = torch.ones(2,5)
    out = torch.cat((A,C),1)
    tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.]])
     max_test = torch.Tensor([[5,8,1],[3,1,9]])
    tensor([[5., 8., 1.],
            [3., 1., 9.]])
    
    max_test.max(1,keepdim=True)
    values=tensor([[8.],
            [9.]]),
    indices=tensor([[1],
            [2]]))
    
     max_test.max(1)
    torch.return_types.max(
    values=tensor([8., 9.]),
    indices=tensor([1, 2]))
    
    max_test.max(0)
    values=tensor([5., 8., 9.]),
    indices=tensor([0, 0, 1]))
    
    max_test.max(0,keepdim=True)
    torch.return_types.max(
    values=tensor([[5., 8., 9.]]),
    indices=tensor([[0, 0, 1]]))
    valid_idx = torch.tensor([True, False, True, False, False]) #小写的t,long类型
    a = torch.tensor([1,2,3,4,5])
    idx_filter = a[valid_idx]
    tensor([1, 3])
    b = torch.Tensor([[1,2,3]])
    b.squeeze(0)
     b
    tensor([[1., 2., 3.]])
    
    b.squeeze_(0)
    b
    tensor([1., 2., 3.])
    a = torch.ones(3,5)
    index = torch.tensor([0,2])
    a.index_fill_(0,index,100)
    tensor([[100., 100., 100., 100., 100.],
            [  1.,   1.,   1.,   1.,   1.],
            [100., 100., 100., 100., 100.]])
    
    
    b = torch.ones(3,5)
    b.index_fill(1,index,200)
    tensor([[200.,   1., 200.,   1.,   1.],
            [200.,   1., 200.,   1.,   1.],
            [200.,   1., 200.,   1.,   1.]])
     labels= torch.rand(5,4)
    tensor([[0.2833, 0.7600, 0.6912, 0.5421],
            [0.3498, 0.0440, 0.3356, 0.5975],
            [0.9071, 0.2023, 0.9391, 0.2516],
            [0.9536, 0.0939, 0.4833, 0.7402],
            [0.2392, 0.7111, 0.9192, 0.5417]])
     best_idx = torch.tensor([3,3,3,0,0,0,0])
    labels[best_idx]
    tensor([[0.9536, 0.0939, 0.4833, 0.7402],
            [0.9536, 0.0939, 0.4833, 0.7402],
            [0.9536, 0.0939, 0.4833, 0.7402],
            [0.2833, 0.7600, 0.6912, 0.5421],
            [0.2833, 0.7600, 0.6912, 0.5421],
            [0.2833, 0.7600, 0.6912, 0.5421],
            [0.2833, 0.7600, 0.6912, 0.5421]])
  • 相关阅读:
    RabbitMQ官方文档翻译之Simple(一)
    rabbitMq集成Spring后,消费者设置手动ack,并且在业务上控制是否ack
    RabbitMQ消息队列知识点归纳
    理解Java中HashMap的工作原理
    mybatis 主键回显
    quart任务调度框架实战
    springmvc常用注解标签详解
    Java程序员玩Linux学操作系统
    在网页中发起QQ临时对话的方法
    软件测试技术学习总结
  • 原文地址:https://www.cnblogs.com/crazybird123/p/14686357.html
Copyright © 2011-2022 走看看