如何使用TensorCores优化卷积
本文将演示如何在TVM中使用TensorCores编写高性能的卷积调度(schedule)。本文假设卷积的输入具有较大的批处理大小(batch)。建议先阅读"如何在GPU上优化卷积"教程。
TensorCore简介
每个Tensor核心都提供一个4x4x4的矩阵处理阵列,该阵列可以运行 $D = A \times B + C$,其中A,B,C和D是4x4矩阵,如图所示。矩阵乘法输入A和B是FP16矩阵,而累加矩阵C和D可以是FP16或FP32矩阵。
但是,CUDA程序员只能使用warp级原语 `wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` 在Tensor Core上执行16x16x16半精度矩阵乘法。在调用矩阵乘法之前,程序员必须使用原语 `wmma::load_matrix_sync` 将内存中的数据显式地加载到寄存器中。NVCC编译器将该原语转换为多个内存加载指令。在运行时,每个线程从矩阵A加载16个元素,从矩阵B加载16个元素。
准备和算法
本示例使用固定大小的输入张量:256个通道,空间尺寸为14 x 14,批处理大小为256。卷积包含512个大小为3 x 3的滤波器,步幅大小为1,填充大小为1。示例使用NHWCnc内存布局。以下代码定义了TVM中的卷积算法。
import tvm
from tvm import te
import numpy as np
from tvm.contrib import nvcc
# Problem size: dimensions of the input feature map and of the filter bank
batch_size = 256
height = width = 14
in_channels = 256
out_channels = 512
kernel_h = kernel_w = 3
pad_h = pad_w = 1
stride_h = stride_w = 1

# TensorCore shape: edge length of the square tile used to block the
# batch and channel dimensions
block_size = 16

# The blocked layouts below use floor division, so every blocked dimension
# has to be an exact multiple of the tile edge.
assert all(dim % block_size == 0 for dim in (batch_size, in_channels, out_channels))

# Input feature map: (N, H, W, IC, n, ic)
data_shape = (
    batch_size // block_size,
    height,
    width,
    in_channels // block_size,
) + (block_size,) * 2

# Kernel: (H, W, IC, OC, ic, oc)
kernel_shape = (
    kernel_h,
    kernel_w,
    in_channels // block_size,
    out_channels // block_size,
) + (block_size,) * 2

# Output feature map: (N, H, W, OC, n, oc)
output_shape = (
    batch_size // block_size,
    height,
    width,
    out_channels // block_size,
) + (block_size,) * 2
# Reduction axes: the kernel window (kh, kw), the blocked input channels (ic)
# and the position inside one input-channel block (ii)
kh = te.reduce_axis((0, kernel_h), name="kh")
kw = te.reduce_axis((0, kernel_w), name="kw")
ic = te.reduce_axis((0, in_channels // block_size), name="ic")
ii = te.reduce_axis((0, block_size), name="ii")

# Algorithm
A = te.placeholder(data_shape, name="A", dtype="float16")
W = te.placeholder(kernel_shape, name="W", dtype="float16")


def _padded_input(n, h, w, i, nn, ii):
    """Return A shifted by the padding, or a float16 zero outside A's H/W range."""
    inside = tvm.tir.all(h >= pad_h, h - pad_h < height, w >= pad_w, w - pad_w < width)
    return tvm.tir.if_then_else(
        inside,
        A[n, h - pad_h, w - pad_w, i, nn, ii],
        tvm.tir.const(0.0, "float16"),
    )


def _conv_element(n, h, w, o, nn, oo):
    """Return one output element: products are cast to float32 and summed over ic, kh, kw, ii."""
    activation = Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32")
    weight = W[kh, kw, ic, o, ii, oo].astype("float32")
    return te.sum(activation * weight, axis=[ic, kh, kw, ii])


# Same layout as the input, with H and W enlarged by the padding on both sides
Apad = te.compute(
    (
        batch_size // block_size,
        height + 2 * pad_h,
        width + 2 * pad_w,
        in_channels // block_size,
        block_size,
        block_size,
    ),
    _padded_input,
    name="Apad",
)
Conv = te.compute(output_shape, _conv_element, name="Conv")

# Create the schedule and inline the padding stage
s = te.create_schedule(Conv.op)
s[Apad].compute_inline()
存储范围
在传统的GPU调度中,有全局(global)、共享(shared)和本地(local)三种存储范围。为了支持TensorCores,添加了另外三个特殊的存储范围:wmma.matrix_a、wmma.matrix_b和wmma.accumulator。在硬件上,所有片段(fragment)范围都存储在片上寄存器级别,与本地存储器位于同一位置。
# Designate the memory hierarchy
# Stage the padded input and the kernel through "shared" scope for their reader, Conv
AS = s.cache_read(Apad, "shared", [Conv])
WS = s.cache_read(W, "shared", [Conv])
# Stage the shared copies once more, into the "wmma.matrix_a" / "wmma.matrix_b" scopes
AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
# Conv's result is written to a "wmma.accumulator" stage (ConvF) first
ConvF = s.cache_write(Conv, "wmma.accumulator")
定义张量特征
实际上,TensorCore是一种特殊的硬件操作。因此,可以使用Tensorize用TensorCore指令替换计算单位。首先,需要定义张量特征。
TensorCore有四种基本操作:fill_fragment、load_matrix、mma_sync和store_matrix。由于fill_fragment和mma_sync都用于矩阵乘法,因此只需编写以下三个内部函数。
def intrin_wmma_load_matrix(scope):
    """Declare a tensor intrinsic that copies one 16x16 float16 tile into ``scope``.

    The source buffer is declared in "shared" scope and the destination
    buffer in the scope supplied by the caller. The element-wise copy is
    lowered to a single ``tir.tvm_load_matrix_sync`` call.

    Parameters
    ----------
    scope : str
        Storage scope of the destination buffer.

    Returns
    -------
    The tensor intrinsic produced by ``te.decl_tensor_intrin``.
    """
    n = 16
    tile_elems = n * n  # 256 elements in one n x n tile

    src = te.placeholder((n, n), name="A", dtype="float16")
    src_buf = tvm.tir.decl_buffer(
        src.shape, src.dtype, scope="shared", data_alignment=32, offset_factor=tile_elems
    )
    dst = te.compute((n, n), lambda i, j: src[i, j], name="C")
    dst_buf = tvm.tir.decl_buffer(
        dst.shape, dst.dtype, scope=scope, data_alignment=32, offset_factor=tile_elems
    )

    def intrin_func(ins, outs):
        """Emit the load call for the bound input and output buffers."""
        shared_tile = ins[0]
        fragment = outs[0]
        builder = tvm.tir.ir_builder.create()
        load_call = tvm.tir.call_intrin(
            "handle",
            "tir.tvm_load_matrix_sync",
            fragment.data,
            n,
            n,
            n,
            # The element offset is divided by the tile size to form the index argument
            fragment.elem_offset // tile_elems,
            shared_tile.access_ptr("r"),
            n,
            "row_major",
        )
        builder.emit(load_call)
        return builder.get()

    return te.decl_tensor_intrin(dst.op, intrin_func, binds={src: src_buf, dst: dst_buf})
def intrin_wmma_gemm():
    """Declare the 16x16x16 multiply-accumulate tensor intrinsic.

    The described computation is ``C[ii, jj] = sum_k A[ii, k] * B[k, jj]``
    where the float16 operands are cast to "float" before multiplying.
    A, B and C are bound to buffers in the "wmma.matrix_a", "wmma.matrix_b"
    and "wmma.accumulator" scopes respectively.

    The lowering function returns three statements, in this order: a
    ``tir.tvm_mma_sync`` call, a ``tir.tvm_fill_fragment`` call that fills
    the accumulator with 0.0, and a second ``tir.tvm_mma_sync`` call.

    Returns
    -------
    The tensor intrinsic produced by ``te.decl_tensor_intrin``.
    """
    n = 16
    tile_elems = n * n  # 256 elements in one n x n tile

    lhs = te.placeholder((n, n), name="A", dtype="float16")
    rhs = te.placeholder((n, n), name="B", dtype="float16")
    k = te.reduce_axis((0, n), name="k")
    acc = te.compute(
        (n, n),
        lambda ii, jj: te.sum(lhs[ii, k].astype("float") * rhs[k, jj].astype("float"), axis=k),
        name="C",
    )

    def declare(tensor, name, scope):
        """Declare a named buffer for ``tensor`` in ``scope`` with the shared alignment settings."""
        return tvm.tir.decl_buffer(
            tensor.shape,
            tensor.dtype,
            name=name,
            scope=scope,
            data_alignment=32,
            offset_factor=tile_elems,
        )

    lhs_buf = declare(lhs, "BA", "wmma.matrix_a")
    rhs_buf = declare(rhs, "BB", "wmma.matrix_b")
    acc_buf = declare(acc, "BC", "wmma.accumulator")

    def intrin_func(ins, outs):
        """Build the statements for the bound operand and accumulator buffers."""
        a_frag, b_frag = ins
        (c_frag,) = outs

        def index_of(buf):
            # The element offset is divided by the tile size to form the index argument
            return buf.elem_offset // tile_elems

        def as_stmt(call):
            builder = tvm.tir.ir_builder.create()
            builder.emit(call)
            return builder.get()

        def init():
            return as_stmt(
                tvm.tir.call_intrin(
                    "handle",
                    "tir.tvm_fill_fragment",
                    c_frag.data,
                    n,
                    n,
                    n,
                    index_of(c_frag),
                    0.0,
                )
            )

        def update():
            # The accumulator is passed twice: first and last argument pair
            return as_stmt(
                tvm.tir.call_intrin(
                    "handle",
                    "tir.tvm_mma_sync",
                    c_frag.data,
                    index_of(c_frag),
                    a_frag.data,
                    index_of(a_frag),
                    b_frag.data,
                    index_of(b_frag),
                    c_frag.data,
                    index_of(c_frag),
                )
            )

        return update(), init(), update()

    return te.decl_tensor_intrin(
        acc.op, intrin_func, binds={lhs: lhs_buf, rhs: rhs_buf, acc: acc_buf}
    )
def intrin_wmma_store_matrix():
    """Declare a tensor intrinsic that copies one 16x16 float32 tile out of the accumulator.

    The source buffer is declared in "wmma.accumulator" scope and the
    destination buffer in "global" scope. The element-wise copy is lowered
    to a single ``tir.tvm_store_matrix_sync`` call.

    Returns
    -------
    The tensor intrinsic produced by ``te.decl_tensor_intrin``.
    """
    n = 16
    tile_elems = n * n  # 256 elements in one n x n tile

    src = te.placeholder((n, n), name="A", dtype="float32")
    src_buf = tvm.tir.decl_buffer(
        src.shape, src.dtype, scope="wmma.accumulator", data_alignment=32, offset_factor=tile_elems
    )
    dst = te.compute((n, n), lambda i, j: src[i, j], name="C")
    dst_buf = tvm.tir.decl_buffer(
        dst.shape, dst.dtype, scope="global", data_alignment=32, offset_factor=tile_elems
    )

    def intrin_func(ins, outs):
        """Emit the store call for the bound input and output buffers."""
        fragment = ins[0]
        target = outs[0]
        builder = tvm.tir.ir_builder.create()
        store_call = tvm.tir.call_intrin(
            "handle",
            "tir.tvm_store_matrix_sync",
            fragment.data,
            n,
            n,
            n,
            # The element offset is divided by the tile size to form the index argument
            fragment.elem_offset // tile_elems,
            target.access_ptr("w"),
            n,
            "row_major",
        )
        builder.emit(store_call)
        return builder.get()

    return te.decl_tensor_intrin(dst.op, intrin_func, binds={src: src_buf, dst: dst_buf})
调度计算
要在TVM中使用TensorCores,必须将计算调度到特定的结构中以匹配张量特征。与传统的GPU程序一样,可以使用共享内存来提高速度。如果对阻塞和共享内存有任何疑问,请参阅如何在GPU上优化卷积。
在此示例中,每个线程块(block)包含2x4个warp,并且每个warp调用4x2条TensorCore指令。因此,每个warp的输出形状为64x32,每个线程块输出128x128的图块(tile)。由于共享内存空间的限制,一次只能加载2个块(2x128x128个图块)。
warp操作
请注意,所有TensorCore指令均为warp级指令,这意味着warp中的所有32个线程应同时执行此指令。将threadIdx.x的范围(extent)设为32是解决此问题的最简单方法之一。然后可以将threadIdx.x绑定到任何循环,除了那些直接或间接包含TensorCore内部函数的循环。还要注意,这不是唯一的解决方案。唯一要做的是确保warp中的所有线程可以同时调用TensorCore。
# Define tiling sizes
block_row_warps = 4
block_col_warps = 2
warp_row_tiles = 2
warp_col_tiles = 4
warp_size = 32
chunk = 2
# Handles for the CUDA block / thread indices that the bind() calls refer to
block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
block_z = te.thread_axis("blockIdx.z")
thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")
thread_z = te.thread_axis("threadIdx.z")
# Fuse the two spatial output axes and map the fused axis to blockIdx.z
nc, hc, wc, oc, nnc, ooc = Conv.op.axis
block_k = s[Conv].fuse(hc, wc)
s[Conv].bind(block_k, block_z)
# Split the blocked batch axis twice: nci (extent warp_row_tiles) is innermost,
# nc (extent block_row_warps) is in the middle, block_i is the remainder
nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
block_i, nc = s[Conv].split(nc, factor=block_row_warps)
# Same two-level split for the blocked output-channel axis
oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
block_j, oc = s[Conv].split(oc, factor=block_col_warps)
# Loop order: block-level axes, then the axes bound to threads, then the per-thread tiles
s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
s[Conv].bind(block_i, block_x)
s[Conv].bind(block_j, block_y)
s[Conv].bind(nc, thread_y)
s[Conv].bind(oc, thread_z)
# Schedule local computation
s[ConvF].compute_at(s[Conv], oc)
n, h, w, o, nnf, oof = ConvF.op.axis
# At this point ic, kh, kw and ii are still the reduction axes created with te.reduce_axis
ko, ki = s[ConvF].split(ic, factor=chunk)
s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)
# Move intermediate computation into each output compute tile
s[AF].compute_at(s[ConvF], kw)
s[WF].compute_at(s[ConvF], kw)
# Schedule for A's share memory
s[AS].compute_at(s[ConvF], kh)
# NOTE: this unpacking rebinds n, h, w and ii to the axes of AS
n, h, w, i, nn, ii = AS.op.axis
tx, xo = s[AS].split(n, nparts=block_row_warps)
ty, yo = s[AS].split(xo, nparts=block_col_warps)
t = s[AS].fuse(nn, ii)
# factor=warp_size: the inner part ti has extent warp_size and is bound to threadIdx.x
to, ti = s[AS].split(t, factor=warp_size)
s[AS].bind(tx, thread_y)
s[AS].bind(ty, thread_z)
s[AS].bind(ti, thread_x)
# Schedule for W's share memory
# compute_at needs the reduction axis kh, so it runs before kh is rebound on the next line
s[WS].compute_at(s[ConvF], kh)
# NOTE: this unpacking rebinds kh, kw, ic and ii to the axes of WS
kh, kw, ic, o, ii, oo = WS.op.axis
tx, xo = s[WS].split(o, nparts=block_row_warps)
ty, yo = s[WS].split(xo, nparts=block_col_warps)
t = s[WS].fuse(ii, oo)
# nparts=warp_size: the outer part to has warp_size parts and is bound to threadIdx.x,
# the inner part ti is vectorized
to, ti = s[WS].split(t, nparts=warp_size)
s[WS].bind(tx, thread_y)
s[WS].bind(ty, thread_z)
s[WS].bind(to, thread_x)
s[WS].vectorize(ti)
# Print the lowered IR of the schedule built so far
print(tvm.lower(s, [A, W, Conv], simple_mode=True))
输出:
primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {Conv: Buffer(Conv_2: Pointer(float32), float32, [16, 14, 14, 32, 16, 16], []),
W: Buffer(W_2: Pointer(float16), float16, [3, 3, 16, 32, 16, 16], []),
A: Buffer(A_2: Pointer(float16), float16, [16, 14, 14, 16, 16, 16], [])}
buffer_map = {A_1: A, W_1: W, Conv_1: Conv} {
attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
attr [Conv.wmma.accumulator: Pointer(float32)] "storage_scope" = "wmma.accumulator";
allocate(Conv.wmma.accumulator, float32, [2048]);
attr [Apad.shared: Pointer(float16)] "storage_scope" = "shared";
allocate(Apad.shared, float16, [12288]);
attr [W.shared: Pointer(float16)] "storage_scope" = "shared";
allocate(W.shared, float16, [12288]);
attr [Apad.shared.wmma.matrix_a: Pointer(float16)] "storage_scope" = "wmma.matrix_a";
allocate(Apad.shared.wmma.matrix_a, float16, [512]);
attr [W.shared.wmma.matrix_b: Pointer(float16)] "storage_scope" = "wmma.matrix_b";
allocate(W.shared.wmma.matrix_b, float16, [1024]);
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
for (n.c.init: int32, 0, 2) {
for (o.c.init: int32, 0, 4) {
for (nn.c.init: int32, 0, 16) {
for (oo.c.init: int32, 0, 16) {
Conv.wmma.accumulator[((((n.c.init*1024) + (o.c.init*256)) + (nn.c.init*16)) + oo.c.init)] = 0f32
}
}
}
}
for (ic.outer: int32, 0, 8) {
for (kh: int32, 0, 3) {
for (ax2: int32, 0, 3) {
for (ax3: int32, 0, 2) {
for (ax4.ax5.fused.outer: int32, 0, 8) {
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = @tir.if_then_else(((((1 <= (floordiv(blockIdx.z, 14) + kh)) && ((floordiv(blockIdx.z, 14) + kh) < 15)) && (1 <= (ax2 + floormod(blockIdx.z, 14)))) && ((ax2 + floormod(blockIdx.z, 14)) < 15)), (float16*)A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)], 0f16, dtype=float16)
}
}
}
for (ax1: int32, 0, 3) {
for (ax2_1: int32, 0, 2) {
attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = (float16x8*)W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)]
}
}
for (ic.inner: int32, 0, 2) {
for (kw: int32, 0, 3) {
for (ax0: int32, 0, 2) {
for (ax4: int32, 0, 16) {
for (ax5: int32, 0, 16) {
Apad.shared.wmma.matrix_a[(((ax0*256) + (ax4*16)) + ax5)] = (float16*)Apad.shared[((((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)) + (ax4*16)) + ax5)]
}
}
}
for (ax3_1: int32, 0, 4) {
for (ax4_1: int32, 0, 16) {
for (ax5_1: int32, 0, 16) {
W.shared.wmma.matrix_b[(((ax3_1*256) + (ax4_1*16)) + ax5_1)] = (float16*)W.shared[((((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)) + (ax4_1*16)) + ax5_1)]
}
}
}
for (n.c: int32, 0, 2) {
for (o.c: int32, 0, 4) {
for (nn.c: int32, 0, 16) {
for (oo.c: int32, 0, 16) {
for (ii: int32, 0, 16) {
Conv.wmma.accumulator[((((n.c*1024) + (o.c*256)) + (nn.c*16)) + oo.c)] = ((float32*)Conv.wmma.accumulator[((((n.c*1024) + (o.c*256)) + (nn.c*16)) + oo.c)] + (cast(float32, (float16*)Apad.shared.wmma.matrix_a[(((n.c*256) + (nn.c*16)) + ii)])*cast(float32, (float16*)W.shared.wmma.matrix_b[(((o.c*256) + (ii*16)) + oo.c)])))
}
}
}
}
}
}
}
}
}
for (n.inner: int32, 0, 2) {
for (o.inner: int32, 0, 4) {
for (nn: int32, 0, 16) {
for (oo: int32, 0, 16) {
Conv_2[(((((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)) + (nn*16)) + oo)] = (float32*)Conv.wmma.accumulator[((((n.inner*1024) + (o.inner*256)) + (nn*16)) + oo)]
}
}
}
}
}
}
将计算降级(lower)为内部函数
最后一个阶段是通过将2D卷积映射到张量内部函数(tensor intrinsics),把计算循环降级(lower)为TensorCore硬件内部函数。
# Tensorize the two fragment-load stages from their second-to-last axis inwards,
# each with a load intrinsic declared for the matching wmma scope
s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix("wmma.matrix_a"))
s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix("wmma.matrix_b"))
# nnc and nnf are the second-to-last axes unpacked from Conv.op.axis and ConvF.op.axis
s[Conv].tensorize(nnc, intrin_wmma_store_matrix())
s[ConvF].tensorize(nnf, intrin_wmma_gemm())
# Print the lowered IR again, now with the intrinsic calls in place
print(tvm.lower(s, [A, W, Conv], simple_mode=True))
输出:
primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {Conv: Buffer(Conv_2: Pointer(float32), float32, [16, 14, 14, 32, 16, 16], []),
W: Buffer(W_2: Pointer(float16), float16, [3, 3, 16, 32, 16, 16], []),
A: Buffer(A_2: Pointer(float16), float16, [16, 14, 14, 16, 16, 16], [])}
buffer_map = {A_1: A, W_1: W, Conv_1: Conv} {
attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
attr [Conv.wmma.accumulator: Pointer(float32)] "storage_scope" = "wmma.accumulator";
allocate(Conv.wmma.accumulator, float32, [2048]);
attr [Apad.shared: Pointer(float16)] "storage_scope" = "shared";
allocate(Apad.shared, float16, [12288]);
attr [W.shared: Pointer(float16)] "storage_scope" = "shared";
allocate(W.shared, float16, [12288]);
attr [Apad.shared.wmma.matrix_a: Pointer(float16)] "storage_scope" = "wmma.matrix_a";
allocate(Apad.shared.wmma.matrix_a, float16, [512]);
attr [W.shared.wmma.matrix_b: Pointer(float16)] "storage_scope" = "wmma.matrix_b";
allocate(W.shared.wmma.matrix_b, float16, [1024]);
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
for (n.c.init: int32, 0, 2) {
for (o.c.init: int32, 0, 4) {
@tir.tvm_fill_fragment(Conv.wmma.accumulator, 16, 16, 16, ((n.c.init*4) + o.c.init), 0f32, dtype=handle)
}
}
for (ic.outer: int32, 0, 8) {
for (kh: int32, 0, 3) {
for (ax2: int32, 0, 3) {
for (ax3: int32, 0, 2) {
for (ax4.ax5.fused.outer: int32, 0, 8) {
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = @tir.if_then_else(((((1 <= (floordiv(blockIdx.z, 14) + kh)) && ((floordiv(blockIdx.z, 14) + kh) < 15)) && (1 <= (ax2 + floormod(blockIdx.z, 14)))) && ((ax2 + floormod(blockIdx.z, 14)) < 15)), (float16*)A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)], 0f16, dtype=float16)
}
}
}
for (ax1: int32, 0, 3) {
for (ax2_1: int32, 0, 2) {
attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = (float16x8*)W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)]
}
}
for (ic.inner: int32, 0, 2) {
for (kw: int32, 0, 3) {
for (ax0: int32, 0, 2) {
@tir.tvm_load_matrix_sync(Apad.shared.wmma.matrix_a, 16, 16, 16, ax0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), Apad.shared, ((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)), 256, 1, dtype=handle), 16, "row_major", dtype=handle)
}
for (ax3_1: int32, 0, 4) {
@tir.tvm_load_matrix_sync(W.shared.wmma.matrix_b, 16, 16, 16, ax3_1, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), W.shared, ((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)), 256, 1, dtype=handle), 16, "row_major", dtype=handle)
}
for (n.c: int32, 0, 2) {
for (o.c: int32, 0, 4) {
@tir.tvm_mma_sync(Conv.wmma.accumulator, ((n.c*4) + o.c), Apad.shared.wmma.matrix_a, n.c, W.shared.wmma.matrix_b, o.c, Conv.wmma.accumulator, ((n.c*4) + o.c), dtype=handle)
}
}
}
}
}
}
for (n.inner: int32, 0, 2) {
for (o.inner: int32, 0, 4) {
@tir.tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*4) + o.inner), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), Conv_2, (((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)), 256, 2, dtype=handle), 16, "row_major", dtype=handle)
}
}
}
}
生成CUDA内核
最后,使用TVM生成和编译CUDA内核,并评估卷积的延迟。由于TensorCores仅在具有Compute Capability 7.0或更高版本的NVIDIA GPU中受支持,因此它可能无法在构建服务器上运行
# Build and time the kernel only when GPU 0 reports TensorCore support;
# otherwise nothing below runs.
ctx = tvm.gpu(0)
if nvcc.have_tensorcore(ctx.compute_version):
    # Only the build runs inside the PassContext, which sets
    # tir.UnrollLoop's auto_max_step to 16.
    with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}):
        func = tvm.build(s, [A, W, Conv], "cuda")

    # Random inputs in the placeholders' dtypes, and a zero-filled output buffer
    a_np = np.random.uniform(size=data_shape).astype(A.dtype)
    w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)

    # Time the kernel with number=10 and report the mean scaled by 1e3 as milliseconds
    evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
    print("conv2d with tensor core: %f ms" % (evaluator(a, w, c).mean * 1e3))
输出:
conv2d with tensor core: 8.329637 ms
概要
本文演示了如何使用TVM调度原语在特定GPU上调用TensorCore。
https://tvm.apache.org/docs/tutorials/optimize/opt_conv_tensorcore.html