  • Compute and Reduce with Tuple Inputs

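    It is often useful to compute several outputs of the same shape inside a single loop nest, or to perform a reduction that tracks several values at once, such as argmax. Both cases can be expressed with tuple inputs.

    This article introduces how tuple inputs are used in TVM.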

    from __future__ import absolute_import, print_function
     
    import tvm
    from tvm import te
    import numpy as np

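    Describing Batchwise Computation

    When several operators have the same shape and should be scheduled together in a later scheduling step, they can be written as a tuple inside a single te.compute.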

    n = te.var("n")
    m = te.var("m")
    A0 = te.placeholder((m, n), name="A0")
    A1 = te.placeholder((m, n), name="A1")
    B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B")
     
    # The generated IR code would be:
    s = te.create_schedule(B0.op)
    print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))

    Output:

    primfn(A0_1: handle, A1_1: handle, B.v0_1: handle, B.v1_1: handle) -> ()
      attr = {"global_symbol": "main", "tir.noalias": True}
      buffers = {B.v1: Buffer(B.v1_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
                 B.v0: Buffer(B.v0_2: Pointer(float32), float32, [m, n], [stride_2: int32, stride_3: int32], type="auto"),
                 A1: Buffer(A1_2: Pointer(float32), float32, [m, n], [stride_4: int32, stride_5: int32], type="auto"),
                 A0: Buffer(A0_2: Pointer(float32), float32, [m, n], [stride_6: int32, stride_7: int32], type="auto")}
      buffer_map = {A0_1: A0, A1_1: A1, B.v0_1: B.v0, B.v1_1: B.v1} {
      for (i: int32, 0, m) {
        for (j: int32, 0, n) {
          B.v0_2[((i*stride_2) + (j*stride_3))] = ((float32*)A0_2[((i*stride_6) + (j*stride_7))] + 2f32)
          B.v1_2[((i*stride) + (j*stride_1))] = ((float32*)A1_2[((i*stride_4) + (j*stride_5))]*3f32)
        }
      }
    }
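
    To make the lowered IR above easier to read, the following is a minimal pure-Python sketch of the same loop nest. It does not use TVM; batch_compute and the sample inputs are hypothetical names introduced only for illustration. The point it shows is that both outputs are written in the same (i, j) iteration.

```python
# Pure-Python reference for the fused loop nest (illustrative only;
# batch_compute is a hypothetical helper, not part of TVM).
def batch_compute(a0, a1):
    m, n = len(a0), len(a0[0])
    b0 = [[0.0] * n for _ in range(m)]
    b1 = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # Both tuple outputs are produced in one iteration of the loop.
            b0[i][j] = a0[i][j] + 2
            b1[i][j] = a1[i][j] * 3
    return b0, b1


a0 = [[1.0, 2.0], [3.0, 4.0]]
a1 = [[5.0, 6.0], [7.0, 8.0]]
b0, b1 = batch_compute(a0, a1)
```

    With these sample inputs, b0 holds A0 + 2 and b1 holds A1 * 3, mirroring the two stores in the innermost loop of the IR.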

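    Describing Reduction with Collaborative Inputs

    Some reduction operators need several inputs that work together, for example argmax. During the reduction, argmax compares the values of its operands while also keeping track of the index of the winning operand. This can be expressed with te.comm_reducer() as follows: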

    # x and y are the operands of the reduction; each of them is a tuple of
    # (index, value).
    def fcombine(x, y):
        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs
     
     
    # The identity element also needs to be a tuple, so `fidentity` accepts
    # two dtypes as inputs.
    def fidentity(t0, t1):
        return tvm.tir.const(-1, t0), tvm.te.min_value(t1)
     
     
    argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
     
    # describe the reduction computation
    m = te.var("m")
    n = te.var("n")
    idx = te.placeholder((m, n), name="idx", dtype="int32")
    val = te.placeholder((m, n), name="val", dtype="int32")
    k = te.reduce_axis((0, n), "k")
    T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
     
    # the generated IR code would be:
    s = te.create_schedule(T0.op)
    print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))

    Output:

    primfn(idx_1: handle, val_1: handle, T.v0_1: handle, T.v1_1: handle) -> ()
      attr = {"global_symbol": "main", "tir.noalias": True}
      buffers = {T.v1: Buffer(T.v1_2: Pointer(int32), int32, [m: int32], [stride: int32], type="auto"),
                 val: Buffer(val_2: Pointer(int32), int32, [m, n: int32], [stride_1: int32, stride_2: int32], type="auto"),
                 T.v0: Buffer(T.v0_2: Pointer(int32), int32, [m], [stride_3: int32], type="auto"),
                 idx: Buffer(idx_2: Pointer(int32), int32, [m, n], [stride_4: int32, stride_5: int32], type="auto")}
      buffer_map = {idx_1: idx, val_1: val, T.v0_1: T.v0, T.v1_1: T.v1} {
      for (i: int32, 0, m) {
        T.v0_2[(i*stride_3)] = -1
        T.v1_2[(i*stride)] = -2147483648
        for (k: int32, 0, n) {
          T.v0_2[(i*stride_3)] = @tir.if_then_else(((int32*)val_2[((i*stride_1) + (k*stride_2))] <= (int32*)T.v1_2[(i*stride)]), (int32*)T.v0_2[(i*stride_3)], (int32*)idx_2[((i*stride_4) + (k*stride_5))], dtype=int32)
          T.v1_2[(i*stride)] = @tir.if_then_else(((int32*)val_2[((i*stride_1) + (k*stride_2))] <= (int32*)T.v1_2[(i*stride)]), (int32*)T.v1_2[(i*stride)], (int32*)val_2[((i*stride_1) + (k*stride_2))], dtype=int32)
        }
      }
    }
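
    The reducer's behaviour can be checked without TVM. The sketch below is a plain-Python analogue that assumes, as the lowered IR suggests, that x is the running (index, value) pair and y is the incoming one. The names argmax_row, argmax_rows and INT32_MIN are hypothetical and exist only in this example.

```python
from functools import reduce

INT32_MIN = -2**31  # same initial value as the -2147483648 in the IR


def fcombine(x, y):
    # x is the running (index, value) pair, y the incoming one.
    # When x[1] >= y[1] (including ties) the running pair is kept.
    return x if x[1] >= y[1] else y


def argmax_row(idx_row, val_row):
    # The identity element is itself a tuple: (-1, smallest int32).
    identity = (-1, INT32_MIN)
    return reduce(fcombine, zip(idx_row, val_row), identity)


def argmax_rows(idx, val):
    pairs = [argmax_row(i_row, v_row) for i_row, v_row in zip(idx, val)]
    t0 = [p[0] for p in pairs]  # winning indices, like T.v0
    t1 = [p[1] for p in pairs]  # winning values, like T.v1
    return t0, t1


idx = [[0, 1, 2], [0, 1, 2]]
val = [[3, 9, 9], [-5, -7, -6]]
t0, t1 = argmax_rows(idx, val)
```

    Because ties keep the running pair, the first occurrence of the maximum wins in this sketch, and an empty row returns the identity (-1, INT32_MIN).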

    Note

    Readers who are not familiar with reduction can refer to "Define General Commutative Reduction Operation".

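    Scheduling Operations with Tuple Inputs

    Although a single batch operation produces multiple outputs, those outputs belong to one operation and can only be scheduled together.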

    n = te.var("n")
    m = te.var("m")
    A0 = te.placeholder((m, n), name="A0")
    B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name="B")
    A1 = te.placeholder((m, n), name="A1")
    C = te.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name="C")
     
    s = te.create_schedule(C.op)
    s[B0].compute_at(s[C], C.op.axis[0])
    # The generated IR below shows that B0 and B1 are both computed inside C's outer loop:
    print(tvm.lower(s, [A0, A1, C], simple_mode=True))

    Output:

    primfn(A0_1: handle, A1_1: handle, C_1: handle) -> ()
      attr = {"global_symbol": "main", "tir.noalias": True}
      buffers = {C: Buffer(C_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
                 A1: Buffer(A1_2: Pointer(float32), float32, [m, n], [stride_2: int32, stride_3: int32], type="auto"),
                 A0: Buffer(A0_2: Pointer(float32), float32, [m, n], [stride_4: int32, stride_5: int32], type="auto")}
      buffer_map = {A0_1: A0, A1_1: A1, C_1: C} {
      attr [B.v0: Pointer(float32)] "storage_scope" = "global";
      allocate(B.v0, float32, [n]);
      attr [B.v1: Pointer(float32)] "storage_scope" = "global";
      allocate(B.v1, float32, [n]);
      for (i: int32, 0, m) {
        for (j: int32, 0, n) {
          B.v0[j] = ((float32*)A0_2[((i*stride_4) + (j*stride_5))] + 2f32)
          B.v1[j] = ((float32*)A0_2[((i*stride_4) + (j*stride_5))]*3f32)
        }
        for (j_1: int32, 0, n) {
          C_2[((i*stride) + (j_1*stride_1))] = ((float32*)A1_2[((i*stride_2) + (j_1*stride_3))] + (float32*)B.v0[j_1])
        }
      }
    }
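
    The loop structure of the IR above can be mirrored in plain Python to show the consequence of scheduling per operation. In this sketch (no TVM involved; fused_rows and b1_writes are hypothetical names used only here), the row of B1 is still filled in for every i even though C reads only B0.

```python
# Pure-Python analogue of the compute_at result (illustrative only).
def fused_rows(a0, a1):
    m, n = len(a0), len(a0[0])
    c = [[0.0] * n for _ in range(m)]
    # Row-sized scratch buffers, like the allocate(B.v0/B.v1, [n]) in the IR.
    b0_row = [0.0] * n
    b1_row = [0.0] * n
    b1_writes = 0  # counts stores into the buffer that C never reads
    for i in range(m):
        for j in range(n):
            # B0 and B1 come from one operation, so they move together.
            b0_row[j] = a0[i][j] + 2
            b1_row[j] = a0[i][j] * 3
            b1_writes += 1
        for j in range(n):
            c[i][j] = a1[i][j] + b0_row[j]
    return c, b1_writes


a0 = [[1.0, 2.0], [3.0, 4.0]]
a1 = [[10.0, 20.0], [30.0, 40.0]]
c, b1_writes = fused_rows(a0, a1)
```

    The counter makes the cost visible: every element of B1 is written once per row even though only B0 feeds into C.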

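    Summary

    This article introduced the usage of tuple inputs.

    • Describe ordinary batchwise computation.
    • Describe reduction operations with tuple inputs.
    • Note that computation can only be scheduled per operation, not per tensor.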

     

  • Original article: https://www.cnblogs.com/wujianming-110117/p/14186200.html