zoukankan      html  css  js  c++  java
  • DLPack构建跨框架的深度学习编译器

    DLPack构建跨框架的深度学习编译器

    Tensorflow,PyTorch和ApacheMxNet等深度学习框架提供了一个功能强大的工具包,可用于快速进行原型设计和部署深度学习模型。易用性通常是以碎片为代价的:孤立地使用每个框架是很容易的。垂直集成已使常见用例的开发流程简化了,但是冒险走过的路可能很棘手。

    一个支持不佳的方案是将张量直接从一个框架传递到内存中的另一个框架,而没有任何数据重复或复制。支持这种用例使用户能够将管道串联在一起,其中某些算子在一个框架中得到比在另一个框架中得到更好的支持(或更快速)。框架之间共享的数据表示形式也将弥合这一差距,并在为算子生成代码时,允许编译器堆栈以单一格式为目标。

    DLPack是用于张量数据结构的中间内存表示标准。使用DLPack作为通用表示,传统上只能依赖供应商提供的库的框架编写的脚本中利用TVM。TVM打包函数可以在DLPack张量上运行,提供包装程序以桥接带有零数据副本的框架(例如PyTorch和MxNet)中的张量数据结构。

    DLPack提供了一种简单的可移植内存数据结构:

    /*!
     * rief Plain C Tensor object, does not manage memory.
     */
    typedef struct {
      /*!
       * rief The opaque data pointer points to the allocated data.
       *  This will be CUDA device pointer or cl_mem handle in OpenCL.
       *  This pointer is always aligns to 256 bytes as in CUDA.
       */
      void* data;
      /*! rief The device context of the tensor */
      DLContext ctx;
      /*! rief Number of dimensions */
      int ndim;
      /*! rief The data type of the pointer*/
      DLDataType dtype;
      /*! rief The shape of the tensor */
      int64_t* shape;
      /*!
       * rief strides of the tensor,
       *  can be NULL, indicating tensor is compact.
       */
      int64_t* strides;
      /*! rief The offset in bytes to the beginning pointer to data */
      uint64_t byte_offset;
    } DLTensor;

    例如,在TVM中声明并编译一个矩阵乘法算子,并构建一个使用DLPack表示形式的包装器wrapper,允许该算子支持PyTorch张量。还使用MxNet重复此演示。此扩展使机器学习开发人员可以在不牺牲性能的情况下,将代码快速移植到相对不受支持的硬件平台上。

    DLPack如何提供框架和TVM之间共享的中间包wrapper的说明:

     图1

    首先,在PyTorch中计算参考输出:

        import torch
        x = torch.rand(56,56)
        y = torch.rand(56,56)
        z = x.mm(y)

    然后,使用默认调度定义并构建TVM矩阵乘法算子:

        n = tvm.convert(56)
        X = tvm.placeholder((n,n), name='X')
        Y = tvm.placeholder((n,n), name='Y')
     
        k = tvm.reduce_axis((0, n), name='k')
        Z = tvm.compute((n,n), lambda i,j : tvm.sum(X[i,k]*Y[k,j], axis=k))
        s = tvm.create_schedule(Z.op)
        fmm = tvm.build(s, [X, Y, Z], target_host='llvm', name='fmm')

    为简便起见,没有涵盖可用于优化矩阵乘法的TVM大量的调度原语集合。如果希望使自定义GEMM算子在的硬件设备上快速运行,请参考详细的教程。

    然后,将TVM函数转换为支持PyTorch张量的函数:

        from tvm.contrib.dlpack import to_pytorch_func
        # fmm is the previously built TVM function (Python function)
        # fmm is the wrapped TVM function (Python function)
        fmm_pytorch = to_pytorch_func(fmm)
        z2 = torch.empty(56,56)
        fmm_pytorch(x, y, z2)
        np.testing.assert_allclose(z.numpy(), z2.numpy())

    并验证结果是否匹配。

    可以重复相同的示例,但是使用MxNet代替:

        import mxnet
        from tvm.contrib.mxnet import to_mxnet_func
        ctx = mxnet.cpu(0)
        x = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
        y = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
        z = mxnet.nd.empty(shape=(56,56), ctx=ctx)
        f = tvm.build(s, [X, Y, Z], target_host='llvm', name='f')
        f_mxnet = to_mxnet_func(f)
        f_mxnet(x, y, z)
        np.testing.assert_allclose(z.asnumpy(), x.asnumpy().dot(y.asnumpy()))

    在PyTorch示例的幕后

    由于TVM提供了将dlpack张量转换为tvm的功能NDArray反之亦然,因此,通过wrapper功能,所需的只是一些语法 syntactic sugar 。 convert_func是用于使用具有dlpack支持的张量的框架的通用转换器,可以用于实现方便的转换器,例如 to_pytorch_func

    def convert_func(tvm_func, tensor_type, to_dlpack_func):
        assert callable(tvm_func)
     
        def _wrapper(*args):
            args = tuple(ndarray.from_dlpack(to_dlpack_func(arg))
                if isinstance(arg, tensor_type) else arg for arg in args)
            return tvm_func(*args)
     
        return _wrapper
     
    def to_pytorch_func(tvm_func):
        import torch
        import torch.utils.dlpack
        return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)

     

    人工智能芯片与自动驾驶
  • 相关阅读:
    重构第四天 : 用多态替换条件语句(if else & switch)
    MSBuild 教程(2)
    为什么Nhibernate中属性和方法必须Virtual的
    重构第三天:提升方法&下移方法
    重构第二天:移动方法
    重构第一天:封装集合
    MSbuild 教程
    工程经验总结之吹水"管理大境界"
    呕心沥血之作:完美解决Informix的中文乱码问题
    万事开头难——我的蛮荒时代
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/14503412.html
Copyright © 2011-2022 走看看