zoukankan      html  css  js  c++  java
  • pytorch之Tensor

    #tensor和numpy
    import torch
    import numpy as np

    numpy_tensor = np.random.randn(3,4)
    print(numpy_tensor)
    #将numpy的ndarray转换到tendor上
    pytorch_tensor1 = torch.Tensor(numpy_tensor)
    pytorch_tensor2 = torch.from_numpy(numpy_tensor)
    print(pytorch_tensor1)
    print(pytorch_tensor2)
    #将pytorch的tensor转换到numpy的ndarray
    numpy_array = pytorch_tensor1.numpy()   #如果pytorch在cpu上
    print(numpy_array)
    #tensor的一些属性,得到tensor的大小
    print(pytorch_tensor1.shape)
    print(pytorch_tensor1.size())
    print(pytorch_tensor1.type()) #得到tensor的数据类型
    print(pytorch_tensor1.dim()) #得到tensor的维度
    print(pytorch_tensor1.numel()) #得到tensor所有元素的个数

    x = torch.rand(3,2)
    x.type(torch.DoubleTensor)
    print(x)
    np_array = x.numpy()
    print(np_array.dtype)

    [[ 1.05174423  1.09272735  0.46027768 -0.03255727]
     [ 0.57027229  1.22165706 -0.77909099 -0.17678552]
     [ 0.02112402 -1.08971068  0.72317744 -1.45482622]]
    tensor([[ 1.0517,  1.0927,  0.4603, -0.0326],
            [ 0.5703,  1.2217, -0.7791, -0.1768],
            [ 0.0211, -1.0897,  0.7232, -1.4548]])
    tensor([[ 1.0517,  1.0927,  0.4603, -0.0326],
            [ 0.5703,  1.2217, -0.7791, -0.1768],
            [ 0.0211, -1.0897,  0.7232, -1.4548]], dtype=torch.float64)
    [[ 1.0517442   1.0927273   0.46027768 -0.03255726]
     [ 0.57027227  1.221657   -0.779091   -0.17678553]
     [ 0.02112402 -1.0897107   0.72317743 -1.4548262 ]]
    torch.Size([3, 4])
    torch.Size([3, 4])
    torch.FloatTensor
    2
    12
    tensor([[0.1810, 0.5168],
            [0.9859, 0.1294],
            [0.9262, 0.6952]])
    float32

    #Tensor的操作1
    import torch
    x = torch.ones(2,3)
    print(x)
    print(x.type())
    x = x.long()
    print(x.type())
    x = x.float()
    print(x.type())

    y = torch.rand(3,4)
    print(y)
    #沿着行取最大值
    maxval,maxindex = torch.max(y,dim=1)
    print(maxval,' ',maxindex)

    #沿着行对y求和
    sum = torch.sum(y,dim=1)
    print(sum)

    tensor([[1., 1., 1.],
            [1., 1., 1.]])
    torch.FloatTensor
    torch.LongTensor
    torch.FloatTensor
    tensor([[0.8910, 0.0130, 0.9600, 0.6760],
            [0.5184, 0.6240, 0.9589, 0.2151],
            [0.6904, 0.3474, 0.7502, 0.2055]])
    tensor([0.9600, 0.9589, 0.7502]) 
     tensor([2, 2, 2])
    tensor([2.5400, 2.3164, 1.9936])

    #Tensor操作2
    import torch

    x = torch.rand(3,2)
    print(x)
    print(x.size())
    #增加一个维度
    x = x.unsqueeze(0)
    print(x.size())
    #减少一个维度
    x = x.squeeze(0)
    print(x.size())
    #增加回来
    x = x.unsqueeze(1)
    print(x.size())
    #使用permute和transpose来对矩阵维度进行变换
    #permute 可以重新排列tensor的维度
    #transpose 可以交换两个维度
    x = x.permute(1,0,2)
    print(x.size())
    x = x.transpose(0,2)
    print(x.size())

    tensor([[0.9131, 0.2160],
            [0.0987, 0.5013],
            [0.1715, 0.8862]])
    torch.Size([3, 2])
    torch.Size([1, 3, 2])
    torch.Size([3, 2])
    torch.Size([3, 1, 2])
    torch.Size([1, 3, 2])
    torch.Size([2, 3, 1])

    #使用view对tensor进行reshape

    import torch
    x = torch.rand(3,4,5)
    print(x.shape)
    x = x.view(-1,5)
    print(x.size())
    x = x.view(60)
    print(x.shape)

    #两个Tensor求和
    a = torch.rand(3,4)
    b = torch.rand(3,4)
    c = a + b
    print(c)
    z = torch.add(a,b)
    print(z)

    torch.Size([3, 4, 5])
    torch.Size([12, 5])
    torch.Size([60])
    tensor([[0.8822, 1.3766, 1.3586, 0.8951],
            [1.0096, 0.5511, 0.2035, 0.9684],
            [1.2502, 0.0963, 1.3955, 0.9479]])
    tensor([[0.8822, 1.3766, 1.3586, 0.8951],
            [1.0096, 0.5511, 0.2035, 0.9684],
            [1.2502, 0.0963, 1.3955, 0.9479]])

    import torch
    x = torch.ones(4,4)
    print(x)
    x[1:3,1:3] = 2
    print(x)

    tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])
    tensor([[1., 1., 1., 1.],
            [1., 2., 2., 1.],
            [1., 2., 2., 1.],
            [1., 1., 1., 1.]])
  • 相关阅读:
    线性参考
    unix下安装Server(静默方式)
    ArcGIS Server REST开发模式
    Python中调用AO
    Oracle 冷备份
    平头缓冲
    Oracle 热备份
    Socket获取远程连接者的IP
    c#调用cmd执行相关命令
    C#_winform_DataGridView_的18种常见属性 (转)
  • 原文地址:https://www.cnblogs.com/ryluo/p/10170687.html
Copyright © 2011-2022 走看看