zoukankan      html  css  js  c++  java
  • 0402-Tensor和Numpy的区别

    0402-Tensor和Numpy的区别

    pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

    由于tensor和ndarray具有很高的相似性,并且两者相互转化需要的开销很小。但是由于ndarray出现时间较早,相比较tensor有更多更简便的方法,因此在某些时候tensor无法实现某些功能,可以把tensor转换为ndarray格式进行处理后再转换为tensor格式。

    一、tensor数据和ndarray数据相互转换

    import numpy as np
    
    a = np.ones([2, 3], dtype=np.float32)
    a
    
    array([[1., 1., 1.],
           [1., 1., 1.]], dtype=float32)
    
    
    
    b = t.from_numpy(a)  # 把ndarray数据转换为tensor数据
    b
    
    tensor([[1., 1., 1.],
            [1., 1., 1.]])
    
    b = t.Tensor(a)  # 把ndarray数据转换为tensor数据
    b
    
    tensor([[1., 1., 1.],
            [1., 1., 1.]])
    
    a[0, 1] = 100
    b
    
    tensor([[  1., 100.,   1.],
            [  1.,   1.,   1.]])
    
    c = b.numpy()  # 把tensor数据转换为ndarray数据
    c
    
    array([[  1., 100.,   1.],
           [  1.,   1.,   1.]], dtype=float32)
    

    二、广播法则

    广播法则来源于numpy,它的定义如下:

    • 让所有输入数组都向其中shape最长的数组看齐,shape中不足部分通过在前面加1补齐
    • 两个数组要么在某一个维度的长度一致,要么其中一个为1,否则不能计算
    • 当输入数组的某个维度的长度为1时,计算时沿此维度复制扩充×一样的形状

    torch当前支持自动广播法则,但更推荐使用以下两个方法进行手动广播,这样更直观,更不容出错:

    1. unsqueeze或view:为数据某一维的形状补1
    2. expand或expand_as:重复数组,实现当输入的数组的某个维度的长度为1时,计算时沿此维度复制扩充成一样的形状

    注:repeat与expand功能相似,但是repeat会把相同数据复制多份,而expand不会占用额外空间,只会在需要的时候才扩充,可以极大地节省内存。

    a = t.ones(3, 2)
    b = t.zeros(2, 3, 1)
    

    自动广播法则:

    1. a是二维,b是三维,所在现在较小的a前面补1(等价于a.unsqueeze(0),a的形状变成(0,2,3))
    2. 由于a和b在第一维和第三维的形状不一样,利用广播法则,两个形状都变成了(2,3,2)
    a + b
    
    tensor([[[1., 1.],
             [1., 1.],
             [1., 1.]],
    
            [[1., 1.],
             [1., 1.],
             [1., 1.]]])
    

    对上述自动广播可以通过以下方法实现手动广播

    a.unsqueeze(0).expand(2, 3, 2) + b.expand(
        2, 3, 2)  # 等价于a.view(1,3,2).expand(2,3,2) + b.expand(2,3,2)
    
    tensor([[[1., 1.],
             [1., 1.],
             [1., 1.]],
    
            [[1., 1.],
             [1., 1.],
             [1., 1.]]])
  • 相关阅读:
    centos7安装zabbix4.2
    python3.x 基础三:文件IO
    python3.x 基础三:字符集问题
    python3.x 基础三:set集合
    python3.x 基础二:内置函数
    python3.x 基础一:dict字典
    python3.x 基础一:str字符串方法
    python3.x 基础一
    [leetcode]Path Sum
    [leetcode]Balanced Binary Tree
  • 原文地址:https://www.cnblogs.com/nickchen121/p/14697550.html
Copyright © 2011-2022 走看看