zoukankan      html  css  js  c++  java
  • Pytorch-数据类型

    1.张量数据类型

    Pytorch常用的数据类型,其中FloatTensor、DoubleTensor、ByteTensor、IntTensor最常用。GPU和CPU的Tensor不相同。

    •  数据类型检查使用isinstance()
    import torch
    
    a = torch.randn(2,3)
    #torch.FloatTensor
    a.type()
    #true
    isinstance(a,torch.FloatTensor)
    • 标量,torch.tensor(),t是小写的
    import torch
    
    a = torch.tensor(2)
    #0
    print(len(a.shape))
    #torch.Size([])
    print(a.size())
    • 1维张量
    import torch
    import numpy as np
    
    #torch.tensor里边放的是数
    a = torch.tensor([2.1])
    print(a.shape)
    
    #torch.FloatTensor里边放的是维度
    a = torch.FloatTensor(2)
    #tensor([5.6052e-45, 0.0000e+00])
    print(a)
    
    a = np.ones(2)
    #tensor([1., 1.], dtype=torch.float64)
    b = torch.from_numpy(a)
    print(isinstance(b,torch.DoubleTensor))
    • 多维张量
    import torch
    import numpy as np
    
    # tensor([[[0.9539, 0.4338, 0.9842],
    #          [0.2288, 0.0569, 0.9997]]])
    a = torch.rand(1,2,3)
    #3,返回维度
    a.dim()
    #6,返回元素数,1*2*3
    a.numel()
    • 初始化张量
    import torch
    
    #(0,1)之间均值分布初始化
    a = torch.rand(3,3)
    # tensor([[0.7140, 0.3779, 0.7530],
    #         [0.1225, 0.2168, 0.9868],
    #         [0.6421, 0.0806, 0.1370]])
    print(a)
    
    #接收一个tensor,把a的shape读出来,生成一个a的shape的均值分布
    b = torch.rand_like(a)
    # tensor([[0.6055, 0.3282, 0.4211],
    #         [0.9757, 0.3171, 0.5054],
    #         [0.3429, 0.1091, 0.9734]])
    print(b)
    
    #生成1-10之间的整数,生成形状是[3,3]
    c = torch.randint(1,10,[3,3])
    # tensor([[6, 6, 9],
    #         [8, 1, 1],
    #         [9, 6, 8]])
    print(c)
    • torch.full()
    #生成一个2*3,元素都是7的tnensor
    #生成标量使用[]
    a = torch.full([2,3],7)
    print(a)
    • torch.arange()
    #生成[0,10)之间的等差数列
    #tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    a = torch.arange(0,10)
    #tensor([0, 2, 4, 6, 8])
    b = torch.arange(0,10,2)
    • torch.linspace()
    #[0,10]之间均匀取11个点
    #tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
    a = torch.linspace(0,10,steps=11)
    print(a)
    • torch.ones()、torch.zeros()、torch.eye()
    #单位阵
    a = torch.ones(3,3)#0矩阵
    b = torch.zeros(3,4)#对角阵
    c = torch.eye(3)

    2.索引

    索引是从第0维开始的

    import torch
    
    a = torch.rand(4,3,28,28)
    
    #torch.Size([3, 28, 28])
    print(a[0])
    #torch.Size([28, 28])
    print(a[0,1])
    • :选取
    import torch
    
    a = torch.rand(4,3,28,28)
    #选择0维度0,1两个元素,不包括2
    print(a[:2])
    #选择0维度最后一个元素
    #索引正编号[0,1,2,3],反编号[-4,-3,-2,-1]
    #等价于a[3]
    print(a[-1:])

    总的形式可以表达为:start:end:step,::2表示元素隔1个进行采样

    • 选择特定的行,index_select()
    import torch
    
    a = torch.rand(4,3,28,28)
    
    #第0个维度,选择特定的第0个和第三个元素
    b = a.index_select(0,torch.tensor([0,3]))
    • masked_select(),选择特定位置元素
    import torch
    
    # tensor([[0.1956, 0.1843, 0.2313],
    #         [0.1363, 0.4729, 0.7214],
    #         [0.5356, 0.4904, 0.5742]])
    a = torch.rand(3,3)
    #值大于0.5的元素位置设为1
    # tensor([[0, 0, 0],
    #         [0, 0, 1],
    #         [1, 0, 1]], dtype=torch.uint8)
    mask = a.ge(0.5)
    #tensor([0.7214, 0.5356, 0.5742])
    c = torch.masked_select(a,mask)

    2.维度变换

    • view()、reshape()
    import torch
    
    a = torch.rand(4,1,28,28)
    #torch.Size([4, 784])
    b = a.view(4,28*28)
    #view完以后会丢失a的shape信息
    #b.view(4,28,28,1),这样变换语法上没有错误,但是由于shape和a不匹配,变换的数据并不是原数据,造成数据污染
    • 增加维度,unsqueeze()
    import torch
    
    #torch.Size([2])
    a = torch.tensor([15,22])
    #正数表示在哪个维度之前添加一个维度,值<a.dim()+1
    #torch.Size([2, 1])
    # tensor([[15],
    #         [22]])
    b = a.unsqueeze(1)
    #torch.Size([1, 2])
    #tensor([[15, 22]])
    b = a.unsqueeze(0)
    
    #负数表示在哪个维度之后插入一个维度,值≥-a.dim()-1
    #torch.Size([2, 1])
    b=a.unsqueeze(-1)
    #torch.Size([1, 2])
    b=a.unsqueeze(-2)
    # 正, 0,1, 2, 3
    #    [4,3,28,28]
    # 负,-4,-3,-2,-1
    x = torch.rand(4,3,28,28)
    #等价x.unsqueeze(2)
    q = x.unsqueeze(-3)
    • squeeze(),维度减少
    import torch
    
    x = torch.rand(1,1,28,1)
    #不填参数会把所有为1的维度减少
    a = x.squeeze()
    #把第3维维度减少
    #torch.Size([1, 1, 28])
    b = x.squeeze(3)
    • expand(),维度扩展
    import torch
    
    #维度拓展可以把所有是1的维度扩展到需要维度
    x = torch.tensor([[1,2]])
    # tensor([[1, 2],
    #         [1, 2]])
    x = x.expand(2,2)
    y = torch.tensor([[3,4],[5,6]])
    # tensor([[4, 6],
    #         [6, 8]])
    a = x +y
    • transpose(),交换行列

    •  permute(),把原来的行列交换
    import torch
    
    x = torch.rand(4,3,28,28)
    #交换0维和1维
    #torch.Size([3, 4, 28, 28])
    b = x.permute(1,0,2,3)

    permute和transpose会打乱数据在内存中的位置,如果数据在内存中不连续了,使用contiguous()把数据变成连续的

  • 相关阅读:
    mysql基础-01
    Delphi 和键盘有关的API函数(Keyboard Input)
    Delphi System单元-Odd- 判断是否是奇数
    Delphi 键盘API GetKeyState、GetAsyncKeyState -获取键盘 / 按键值Key的状态
    Delphi 全局热键KeyPress 和 热键 API(RegisterHotKey、UnRegisterHotKey、GlobalAddAtom、GlobalDeleteAtom、GlobalFindAtom)
    Delphi XE Android platform uses-permission[2] AndroidManifest.xml 配置
    Delphi XE Android platform uses-permission[1] 权限列表
    Delphi XE Andriod 文件后缀对应MIME类型
    Delphi XE RTL Androidapi 单元
    Delphi XE FMX TFmxObject 类 和 单元
  • 原文地址:https://www.cnblogs.com/vshen999/p/12145768.html
Copyright © 2011-2022 走看看