zoukankan      html  css  js  c++  java
  • 理解 numpy.rollaxis() 函数

    函数声明

    先看看 numpy.rollaxis() 函数的定义形式,如下:

    rollaxis(a, axis, start=0)

    参数 a 通常为 numpy.ndarray 类型,则 a.ndim表示 numpy 数组的维数;

    参数 axis 通常为 int 类型,范围为 [0, a.ndim);

    参数 start 为 int 类型,默认值为 0,取值范围为 [-a.ndim, a.ndim],如果超过这个范围,则会 raise AxisError。

    函数功能

    numpy.rollaxis() 函数用于滚动(roll)指定轴(axis)到某位置。这个函数可以用更易理解的函数 numpy.moveaxis(a, source, destination) 代替。但由于 numpy.moveaxis() 函数是在 numpy v1.11 版本新增的,为了与之前的版本兼容,这个函数依旧保留。

    具体来说,需要根据 axis 和 normalized start 的比较结果,选择将 axis 滚动到哪个位置上,而其他轴的位置顺序不变。如果 axis 参数值大于或等于 normalized start,则 axis 从后向前滚动,直到 start 位置;如果 axis 参数值小于 normalized start,则 axis 轴从前往后滚动,直到 start 的前一个位置,即 start-1 位置。其中 start 和 normalized start 的对应关系,如下表所示:

    start

    Normalized start

    -(a.ndim+1)

    raise AxisError

    -a.ndim

    0

    -1

    a.ndim-1

    0

    0

    a.ndim

    a.ndim

    a.ndim+1

    raise AxisError

    从表中,可以看出 normalized start 是在 -a.ndim <= start < 0 时, start + a.ndim 的值;在  0 <= start <= a.ndim 时,start 值。

    具体的示例及解释,如下所示

    import numpy as np
    
    a = np.ones((3,4,5,6))
    
    axis, start = 3, 1
    # 因为 3 > 1,所以 axis index 3 移动到 axis index 1(start位置),而其他维度位置不变
    print(np.rollaxis(a, axis=axis, start=start).shape)  # (3,6,4,5)
    # np.moveaxis 的等价调用
    print(np.moveaxis(a, source=axis, destination=start).shape)
    
    axis, start = 2, 0
    # 因为 2 > 0,所以 axis index 2 移动到 axis index 0(start位置),而其他维度位置不变
    print(np.rollaxis(a, axis, start).shape)  # (5,3,4,6)
    # np.moveaxis 的等价调用
    print(np.moveaxis(a, axis, start).shape)
    
    axis, start = 1, 4
    # 因为 1 < 4,所以 axis index 1 移动到 axis index 3(start-1位置),而其他维度位置不变
    print(np.rollaxis(a, axis=axis, start=start).shape)  # (3,5,6,4)
    # np.moveaxis 的等价调用
    print(np.moveaxis(a, source=axis, destination=start-1).shape)

     为了更好理解这个过程,最后看看该函数在 numpy 中实现的核心代码,如下所示:

    def rollaxis(a, axis, start=0):
        """
        Roll the specified axis backwards, until it lies in a given position.
    
        Parameters
        ----------
        a : ndarray
            Input array.
        axis : int
            The axis to be rolled. The positions of the other axes do not
            change relative to one another.
        start : int, optional
            When ``start <= axis``, the axis is rolled back until it lies in
            this position. When ``start > axis``, the axis is rolled until it
            lies before this position. The default, 0, results in a "complete"
            roll. 
    
        Returns
        -------
        res : ndarray
            For NumPy >= 1.10.0 a view of `a` is always returned. For earlier
            NumPy versions a view of `a` is returned only if the order of the
            axes is changed, otherwise the input array is returned.
    
        """
        n = a.ndim
        if start < 0:
            start += n
        msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
        if not (0 <= start < n + 1):
            raise AxisError(msg % ('start', -n, 'start', n + 1, start))
        if axis < start:
            start -= 1
        if axis == start:
            return a[...]
        axes = list(range(0, n))
        axes.remove(axis)
        axes.insert(start, axis)
        return a.transpose(axes)

    参考资料

    [1] numpy.rollaxis API reference. https://numpy.org/doc/stable/reference/generated/numpy.rollaxis.html

  • 相关阅读:
    图解 Kubernetes
    如何构建可伸缩的Web应用?
    2020年软件开发趋势
    3种基础的 REST 安全机制
    为什么你应该使用 Kubernetes(k8s)
    Elasticsearch:是什么?你为什么需要他?
    你在使用什么 Redis 客户端工具?
    ZooKeeper 并不适合做注册中心
    Jmeter(三)_配置元件
    Jmeter(二)_基础元件
  • 原文地址:https://www.cnblogs.com/klchang/p/14459983.html
Copyright © 2011-2022 走看看