zoukankan      html  css  js  c++  java
  • Pytorch的tensor数据类型

    基本类型

    torch.Tensor是一种包含单一数据类型元素的多维矩阵。

    Torch定义了七种CPU tensor类型和八种GPU tensor类型:

    Data tyoe CPU tensor GPU tensor
    32-bit floating point torch.FloatTensor torch.cuda.FloatTensor
    64-bit floating point torch.DoubleTensor torch.cuda.DoubleTensor
    16-bit floating point N/A torch.cuda.HalfTensor
    8-bit integer (unsigned) torch.ByteTensor torch.cuda.ByteTensor
    8-bit integer (signed) torch.CharTensor torch.cuda.CharTensor
    16-bit integer (signed) torch.ShortTensor torch.cuda.ShortTensor
    32-bit integer (signed) torch.IntTensor torch.cuda.IntTensor
    64-bit integer (signed) torch.LongTensor torch.cuda.LongTensor

    torch.DoubleTensor(2, 2) 构建一个22 Double类型的张量
    torch.ByteTensor(2, 2) 构建一个2
    2 Byte类型的张量
    torch.CharTensor(2, 2) 构建一个22 Char类型的张量
    torch.ShortTensor(2, 2) 构建一个2
    2 Short类型的张量
    torch.IntTensor(2, 2) 构建一个22 Int类型的张量
    torch.LongTensor(2, 2) 构建一个2
    2 Long类型的张量

    类型转换

    2.1 CPU和GPU的Tensor之间转换

    从cpu –> gpu,使用data.cuda()即可。
    若从gpu –> cpu,则使用data.cpu()。

    2.2 Tensor与Numpy Array之间的转换

    Tensor –> Numpy.ndarray 可以使用 data.numpy(),其中data的类型为torch.Tensor。
    Numpy.ndarray –> Tensor 可以使用torch.from_numpy(data),其中data的类型为numpy.ndarray。

    2.3 Tensor的基本类型转换(也就是float转double,转byte这种。)

    为了方便测试,我们构建一个新的张量,你要转变成不同的类型只需要根据自己的需求选择即可

    1. tensor = torch.Tensor(2, 5)

    2. torch.long() 将tensor投射为long类型
      newtensor = tensor.long()

    3. torch.half()将tensor投射为半精度浮点(16位浮点)类型
      newtensor = tensor.half()

    4. torch.int()将该tensor投射为int类型
      newtensor = tensor.int()

    5. torch.double()将该tensor投射为double类型
      newtensor = tensor.double()

    6. torch.float()将该tensor投射为float类型
      newtensor = tensor.float()

    7. torch.char()将该tensor投射为char类型
      newtensor = tensor.char()

    8. torch.byte()将该tensor投射为byte类型
      newtensor = tensor.byte()

    9. torch.short()将该tensor投射为short类型
      newtensor = tensor.short()

    如果当你需要提高精度,比如说想把模型从float变为double。那么可以将要训练的模型设置为model = model.double()。此外,还要对所有的张量进行设置:pytorch.set_default_tensor_type('torch.DoubleTensor'),不过double比float要慢很多,要结合实际情况进行思考。

  • 相关阅读:
    leetcode 13. Roman to Integer
    python 判断是否为有效域名
    leetcode 169. Majority Element
    leetcode 733. Flood Fill
    最大信息系数——检测变量之间非线性相关性
    leetcode 453. Minimum Moves to Equal Array Elements
    leetcode 492. Construct the Rectangle
    leetcode 598. Range Addition II
    leetcode 349. Intersection of Two Arrays
    leetcode 171. Excel Sheet Column Number
  • 原文地址:https://www.cnblogs.com/icodeworld/p/11882263.html
Copyright © 2011-2022 走看看