zoukankan      html  css  js  c++  java
  • pytorch基础2

    下面是常见函数的代码例子

      1 import torch
      2 import numpy as np
      3 print("分割线-----------------------------------------")
      4 #加减乘除操作
      5 a = torch.rand(3,4)
      6 b = torch.rand(4)
      7 print(a)
      8 print(b)
      9 print(torch.add(a, b))
     10 print(torch.sub(a, b))
     11 print(torch.mul(a, b))
     12 print(torch.div(a, b))
     13 print(torch.all(torch.eq(a - b,torch.sub(a,b))))#判断torch的减法和python的减法结果是否一致
     14 print("分割线-----------------------------------------")
     15 #矩阵乘法(点乘和叉乘)matmul  mm  @   *
     16 a = torch.ones(2,2)*3
     17 b = torch.ones(2,2)
     18 print(a*b)#点乘积
     19 print(a.matmul(b))#叉乘积
     20 print(a@b)#叉乘积
     21 print(a.mm(b))#叉乘积,相比于前两种,这一种只能适合二维数组的乘积
     22 a = torch.rand(4,3,28,64)
     23 b = torch.rand(4,3,64,32)
     24 #torch.mm(a,b).shape#此时会报错,mm只适合二维
     25 print(torch.matmul(a,b).shape)#torch.Size([4, 3, 28, 32])
     26 b = torch.rand(4,1,64,32)
     27 torch.matmul(a, b).shape #torch.Size([4, 3, 28, 32])
     28 b = torch.rand(4,64,32)
     29 #torch.matmul(a, b).shape ,报错,因为b的4对应a的3无法进行广播,所以报错
     30 print("分割线-----------------------------------------")
     31 #power的使用
     32 a = torch.full([2,2],3)
     33 print(a.pow(2))
     34 print(a**2)
     35 aa = a**2
     36 print(aa.sqrt())
     37 print(aa**(0.5))
     38 print(aa.rsqrt())#开根号后的倒数
     39 print("分割线-----------------------------------------")
     40 #floor(),ceil(),round(),trunc(),frac()的使用
     41 a = torch.tensor(3.14)
     42 print(a.floor(),a.ceil(),a.trunc(),a.frac())#后两个是取整,和取小数
     43 print(a.round()) #四舍五入
     44 print("分割线-----------------------------------------")
     45 #clamp 和 dim,keepdim
     46 grad = torch.rand(2,3)*15
     47 print(grad.max(),grad.median(), grad.min())
     48 print(grad)
     49 print(grad.clamp(10))#小于10的都变成10
     50 print(grad.clamp(3,10))#不在3到10之间的变为3或者10
     51 a = torch.randn(4,10)
     52 print(a.max(dim=1))#返回每一列最大值组成的数组和对应的下标
     53 print(a.argmax(dim = 1))#这个只返回最大值对应的下标
     54 print(a.max(dim=1,keepdim=True))#keepdim的作用是使返回值维度是否仍然为原来的维度不变
     55 print(a.argmax(dim=1,keepdim=True))
     56 print("分割线-----------------------------------------")
     57 #topk和kthvalue
     58 print(a.topk(3,dim=1))#返回每行最大的三个数和下标
     59 print(a.topk(3,dim=1,largest=False))#返回每行最小的三个数和下标
     60 print(a.kthvalue(5,dim=1))#返回每行第五大的数字和对应数字的下标
     61 print("分割线-----------------------------------------")
     62 #矩阵的比较,cat的使用
     63 m = torch.rand(2,2)
     64 n = torch.rand(2,2)
     65 print(m,n)
     66 print(m == n)
     67 print(m.eq(n))
     68 print(m > n)
     69 a =  torch.rand(4,32,8)
     70 b = torch.rand(5,32,8)
     71 print(torch.cat([a,b],dim = 0).shape)#按行进行拼接,torch.Size([9, 32, 8])
     72 #更详细的拼接可以看下面的图
     73 a1 = torch.rand(4,3,32,32)
     74 a2 = torch.rand(4,1,32,32)
     75 #torch.cat([a1,a2],dim = 0).shape#报错原因是如果进行维度0上进行拼接,则要保证其他维度必须一致
     76 a1 = torch.rand(4,3,14,32)
     77 a2 = torch.rand(4,3,14,32)
     78 print(torch.cat([a1,a2],dim=2).shape)#torch.Size([4, 3, 28, 32])
     79 print("分割线-----------------------------------------")
     80 #stack和split的使用
     81 #用来进行维度的扩充,这个就是在dim =2进行扩充
     82 print(torch.stack([a1,a2],dim=2).shape)#torch.Size([4, 3, 2, 14, 32])
     83 aa , bb =a1.split([2,2],dim=0)#拆分成两份每份数目是2,2
     84 print(aa.shape,bb.shape)
     85 aaa,bbb = a1.split(2,dim=0)#每份长度为2
     86 print(aaa.shape,bbb.shape)
     87 aa,bb = a1.chunk(2,dim = 0)#拆成两块,每块一个
     88 print("分割线-----------------------------------------")
     89 #where,gather的使用
     90 cond = torch.rand(2,2)
     91 print(cond)
     92 a = torch.zeros(2,2)
     93 b = torch.ones(2,2)
     94 s = torch.where(cond>0.5,a,b)
     95 print(s)#如果大于0.5对应位置为a的对应位置的值,否则为b的对应位置的值
     96 prob = torch.randn(4,10)
     97 idx = prob.topk(dim =1,k=3)
     98 id = idx[1]
     99 print(id)#索引下标
    100 label= torch.arange(10)+100
    101 d = torch.gather(label.expand(4,10),dim =1,index = id)#获取对应索引下标的值
    102 print(d)

     运行结果如下

    D:anacondaanacondapythonw.exe D:/Code/Python/龙良曲pytorch学习/高级操作.py
    分割线-----------------------------------------
    tensor([[0.5581, 0.2369, 0.1379, 0.3702],
            [0.1565, 0.1022, 0.5839, 0.1778],
            [0.0204, 0.1498, 0.5276, 0.4219]])
    tensor([0.7969, 0.9313, 0.0608, 0.0245])
    tensor([[1.3551, 1.1682, 0.1988, 0.3947],
            [0.9535, 1.0335, 0.6448, 0.2023],
            [0.8173, 1.0811, 0.5884, 0.4464]])
    tensor([[-0.2388, -0.6944,  0.0771,  0.3457],
            [-0.6404, -0.8291,  0.5231,  0.1533],
            [-0.7766, -0.7815,  0.4667,  0.3974]])
    tensor([[0.4448, 0.2206, 0.0084, 0.0091],
            [0.1247, 0.0952, 0.0355, 0.0044],
            [0.0162, 0.1395, 0.0321, 0.0103]])
    tensor([[ 0.7003,  0.2544,  2.2669, 15.1075],
            [ 0.1964,  0.1097,  9.5973,  7.2539],
            [ 0.0255,  0.1609,  8.6706, 17.2148]])
    tensor(True)
    分割线-----------------------------------------
    tensor([[3., 3.],
            [3., 3.]])
    tensor([[6., 6.],
            [6., 6.]])
    tensor([[6., 6.],
            [6., 6.]])
    tensor([[6., 6.],
            [6., 6.]])
    torch.Size([4, 3, 28, 32])
    分割线-----------------------------------------
    tensor([[9., 9.],
            [9., 9.]])
    tensor([[9., 9.],
            [9., 9.]])
    tensor([[3., 3.],
            [3., 3.]])
    tensor([[3., 3.],
            [3., 3.]])
    tensor([[0.3333, 0.3333],
            [0.3333, 0.3333]])
    分割线-----------------------------------------
    tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
    tensor(3.)
    分割线-----------------------------------------
    tensor(14.8811) tensor(8.5843) tensor(5.4463)
    tensor([[10.3914, 14.8811,  8.5843],
            [10.6012,  5.4463,  5.7588]])
    tensor([[10.3914, 14.8811, 10.0000],
            [10.6012, 10.0000, 10.0000]])
    tensor([[10.0000, 10.0000,  8.5843],
            [10.0000,  5.4463,  5.7588]])
    torch.return_types.max(
    values=tensor([1.1859, 0.7394, 1.2261, 0.5407]),
    indices=tensor([5, 1, 2, 4]))
    tensor([5, 1, 2, 4])
    torch.return_types.max(
    values=tensor([[1.1859],
            [0.7394],
            [1.2261],
            [0.5407]]),
    indices=tensor([[5],
            [1],
            [2],
            [4]]))
    tensor([[5],
            [1],
            [2],
            [4]])
    分割线-----------------------------------------
    torch.return_types.topk(
    values=tensor([[ 1.1859,  0.8406,  0.7883],
            [ 0.7394,  0.4172,  0.2871],
            [ 1.2261,  0.9851,  0.9759],
            [ 0.5407,  0.1773, -0.2789]]),
    indices=tensor([[5, 4, 7],
            [1, 2, 4],
            [2, 8, 4],
            [4, 1, 8]]))
    torch.return_types.topk(
    values=tensor([[-1.7351, -0.3469, -0.3116],
            [-1.8399, -1.1521, -0.3790],
            [-1.3753, -0.6663, -0.2762],
            [-1.6875, -1.5461, -0.9697]]),
    indices=tensor([[0, 8, 6],
            [3, 5, 0],
            [7, 1, 5],
            [0, 2, 3]]))
    torch.return_types.kthvalue(
    values=tensor([-0.1758,  0.0470, -0.2039, -0.6223]),
    indices=tensor([2, 9, 3, 6]))
    分割线-----------------------------------------
    tensor([[0.9107, 0.4905],
            [0.6499, 0.3425]]) tensor([[0.6911, 0.9619],
            [0.1428, 0.5437]])
    tensor([[False, False],
            [False, False]])
    tensor([[False, False],
            [False, False]])
    tensor([[ True, False],
            [ True, False]])
    torch.Size([9, 32, 8])
    torch.Size([4, 3, 28, 32])
    分割线-----------------------------------------
    torch.Size([4, 3, 2, 14, 32])
    torch.Size([2, 3, 14, 32]) torch.Size([2, 3, 14, 32])
    torch.Size([2, 3, 14, 32]) torch.Size([2, 3, 14, 32])
    分割线-----------------------------------------
    tensor([[0.7541, 0.3861],
            [0.9605, 0.7175]])
    tensor([[0., 1.],
            [0., 0.]])
    tensor([[5, 4, 6],
            [8, 2, 3],
            [8, 6, 4],
            [6, 2, 1]])
    tensor([[105, 104, 106],
            [108, 102, 103],
            [108, 106, 104],
            [106, 102, 101]])
    
    Process finished with exit code 0
    

      

  • 相关阅读:
    音视频之音频(三)
    音视频之声音(二)
    音视频之图片(一)
    页面错位问题
    苹果账号恢复
    js使用逗号拼接id并去重
    Nginx常用命令
    java拼接字符串、格式化字符串方式
    Ajax 请求
    raw.githubusercontent.com port 443: Connection refused
  • 原文地址:https://www.cnblogs.com/henuliulei/p/11823620.html
Copyright © 2011-2022 走看看