zoukankan      html  css  js  c++  java
  • 如何在CPU上优化GEMM(下)

    如何在CPU上优化GEMM(下)

    Array Packing

    另一个重要的技巧是数组打包。这个技巧是对数组的存储维度进行重新排序,使某个维度上的连续访问模式在扁平化后转换为顺序访问模式。

     

    如上图所示,在对计算进行分块之后,可以观察到B的数组访问模式(扁平化后)是规则的,但并不连续。我们期望经过一些变换后,可以得到连续的访问模式。可以将[16][16]数组重新排序为[16/4][16][4]数组,这样当从打包后的数组中读取相应的值时,B的访问模式将是顺序的。

    # We have to re-write the algorithm slightly.

    packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name="packedB")

    C = te.compute(

        (M, N),

        lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),

        name="C",

    )

     

    s = te.create_schedule(C.op)

     

    xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)

    (k,) = s[C].op.reduce_axis

    ko, ki = s[C].split(k, factor=4)

     

    s[C].reorder(xo, yo, ko, xi, ki, yi)

    s[C].vectorize(yi)

     

    x, y, z = s[packedB].op.axis

    s[packedB].vectorize(z)

    s[packedB].parallel(x)

     

    func = tvm.build(s, [A, B, C], target=target, name="mmult")

    assert func

     

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)

    func(a, b, c)

    tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)

     

    evaluator = func.time_evaluator(func.entry_name, ctx, number=10)

    print("Opt4: %f" % evaluator(a, b, c).mean)

    Out:

    Opt4: 0.105409

    Here is the generated IR after array packing.

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    Out:

    primfn(A_1: handle, B_1: handle, C_1: handle) -> ()

      attr = {"global_symbol": "main", "tir.noalias": True}

      buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),

                 B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], []),

                 A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], [])}

      buffer_map = {A_1: A, B_1: B, C_1: C} {

      attr [packedB: Pointer(float32)] "storage_scope" = "global";

      allocate(packedB, float32x32, [32768]) {

        for (x: int32, 0, 32) "parallel" {

          for (y: int32, 0, 1024) {

            packedB[ramp(((x*32768) + (y*32)), 1, 32)] = (float32x32*)B_2[ramp(((y*1024) + (x*32)), 1, 32)]

          }

        }

        for (x.outer: int32, 0, 32) {

          for (y.outer: int32, 0, 32) {

            for (x.inner.init: int32, 0, 32) {

              C_2[ramp((((x.outer*32768) + (x.inner.init*1024)) + (y.outer*32)), 1, 32)] = broadcast(0f32, 32)

            }

            for (k.outer: int32, 0, 256) {

              for (x.inner: int32, 0, 32) {

                for (k.inner: int32, 0, 4) {

                  C_2[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] = ((float32x32*)C_2[ramp((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)), 1, 32)] + (broadcast((float32*)A_2[((((x.outer*32768) + (x.inner*1024)) + (k.outer*4)) + k.inner)], 32)*(float32x32*)packedB[ramp((((y.outer*32768) + (k.outer*128)) + (k.inner*32)), 1, 32)]))

                }

              }

            }

          }

        }

      }

    }

    Write cache for blocks

    分块后,程序将结果逐块写入C,访问模式不是顺序的。因此,可以使用一个顺序缓存数组来保存块结果,并在所有块结果就绪时写入C。

    s = te.create_schedule(C.op)

     

    # Allocate write cache

    CC = s.cache_write(C, "global")

     

    xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)

     

    # Write cache is computed at yo

    s[CC].compute_at(s[C], yo)

     

    # New inner axes

    xc, yc = s[CC].op.axis

     

    (k,) = s[CC].op.reduce_axis

    ko, ki = s[CC].split(k, factor=4)

    s[CC].reorder(ko, xc, ki, yc)

    s[CC].unroll(ki)

    s[CC].vectorize(yc)

     

    x, y, z = s[packedB].op.axis

    s[packedB].vectorize(z)

    s[packedB].parallel(x)

     

    func = tvm.build(s, [A, B, C], target=target, name="mmult")

    assert func

     

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)

    func(a, b, c)

    tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)

     

    evaluator = func.time_evaluator(func.entry_name, ctx, number=10)

    print("Opt5: %f" % evaluator(a, b, c).mean)

    Out:

    Opt5: 0.098048

    Here is the generated IR after adding the write cache.

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    Out:

    primfn(A_1: handle, B_1: handle, C_1: handle) -> ()

      attr = {"global_symbol": "main", "tir.noalias": True}

      buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),

                 B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], []),

                 A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], [])}

      buffer_map = {A_1: A, B_1: B, C_1: C} {

      attr [packedB: Pointer(float32)] "storage_scope" = "global";

      allocate(packedB, float32x32, [32768]);

      attr [C.global: Pointer(float32)] "storage_scope" = "global";

      allocate(C.global, float32, [1024]) {

        for (x: int32, 0, 32) "parallel" {

          for (y: int32, 0, 1024) {

            packedB[ramp(((x*32768) + (y*32)), 1, 32)] = (float32x32*)B_2[ramp(((y*1024) + (x*32)), 1, 32)]

          }

        }

        for (x.outer: int32, 0, 32) {

          for (y.outer: int32, 0, 32) {

            for (x.c.init: int32, 0, 32) {

              C.global[ramp((x.c.init*32), 1, 32)] = broadcast(0f32, 32)

            }

            for (k.outer: int32, 0, 256) {

              for (x.c: int32, 0, 32) {

                C.global[ramp((x.c*32), 1, 32)] = ((float32x32*)C.global[ramp((x.c*32), 1, 32)] + (broadcast((float32*)A_2[(((x.outer*32768) + (x.c*1024)) + (k.outer*4))], 32)*(float32x32*)packedB[ramp(((y.outer*32768) + (k.outer*128)), 1, 32)]))

                C.global[ramp((x.c*32), 1, 32)] = ((float32x32*)C.global[ramp((x.c*32), 1, 32)] + (broadcast((float32*)A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 1)], 32)*(float32x32*)packedB[ramp((((y.outer*32768) + (k.outer*128)) + 32), 1, 32)]))

                C.global[ramp((x.c*32), 1, 32)] = ((float32x32*)C.global[ramp((x.c*32), 1, 32)] + (broadcast((float32*)A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 2)], 32)*(float32x32*)packedB[ramp((((y.outer*32768) + (k.outer*128)) + 64), 1, 32)]))

                C.global[ramp((x.c*32), 1, 32)] = ((float32x32*)C.global[ramp((x.c*32), 1, 32)] + (broadcast((float32*)A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 3)], 32)*(float32x32*)packedB[ramp((((y.outer*32768) + (k.outer*128)) + 96), 1, 32)]))

              }

            }

            for (x.inner: int32, 0, 32) {

              for (y.inner: int32, 0, 32) {

                C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = (float32*)C.global[((x.inner*32) + y.inner)]

              }

            }

          }

        }

      }

    }

    Parallel

    此外,还可以利用多核处理器来实现线程级的并行化。

    s = te.create_schedule(C.op)

     

    CC = s.cache_write(C, "global")

     

    xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)

     

    s[CC].compute_at(s[C], yo)

     

    xc, yc = s[CC].op.axis

     

    (k,) = s[CC].op.reduce_axis

    ko, ki = s[CC].split(k, factor=4)

    s[CC].reorder(ko, xc, ki, yc)

    s[CC].unroll(ki)

    s[CC].vectorize(yc)

     

    # parallel

    s[C].parallel(xo)

     

    x, y, z = s[packedB].op.axis

    s[packedB].vectorize(z)

    s[packedB].parallel(x)

     

    func = tvm.build(s, [A, B, C], target=target, name="mmult")

    assert func

     

    c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)

    func(a, b, c)

    tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)

     

    evaluator = func.time_evaluator(func.entry_name, ctx, number=50)

    opt6_time = evaluator(a, b, c).mean

    print("Opt6: %f" % opt6_time)

    Out:

    Opt6: 0.032347

    Here is the generated IR after parallelization.

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    Out:

    primfn(A_1: handle, B_1: handle, C_1: handle) -> ()

      attr = {"global_symbol": "main", "tir.noalias": True}

      buffers = {C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),

                 B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], []),

                 A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], [])}

      buffer_map = {A_1: A, B_1: B, C_1: C} {

      attr [packedB: Pointer(float32)] "storage_scope" = "global";

      allocate(packedB, float32x32, [32768]) {

        for (x: int32, 0, 32) "parallel" {

          for (y: int32, 0, 1024) {

            packedB[ramp(((x*32768) + (y*32)), 1, 32)] = (float32x32*)B_2[ramp(((y*1024) + (x*32)), 1, 32)]

          }

        }

        for (x.outer: int32, 0, 32) "parallel" {

          attr [C.global: Pointer(float32)] "storage_scope" = "global";

          allocate(C.global, float32, [1024]);

          for (y.outer: int32, 0, 32) {

            for (x.c.init: int32, 0, 32) {

              C.global[ramp((x.c.init*32), 1, 32)] = broadcast(0f32, 32)

            }

            for (k.outer: int32, 0, 256) {

              for (x.c: int32, 0, 32) {

                C.global[ramp((x.c*32), 1, 32)] = ((float32x32*)C.global[ramp((x.c*32), 1, 32)] + (broadcast((float32*)A_2[(((x.outer*32768) + (x.c*1024)) + (k.outer*4))], 32)*(float32x32*)packedB[ramp(((y.outer*32768) + (k.outer*128)), 1, 32)]))

                C.global[ramp((x.c*32), 1, 32)] = ((float32x32*)C.global[ramp((x.c*32), 1, 32)] + (broadcast((float32*)A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 1)], 32)*(float32x32*)packedB[ramp((((y.outer*32768) + (k.outer*128)) + 32), 1, 32)]))

                C.global[ramp((x.c*32), 1, 32)] = ((float32x32*)C.global[ramp((x.c*32), 1, 32)] + (broadcast((float32*)A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 2)], 32)*(float32x32*)packedB[ramp((((y.outer*32768) + (k.outer*128)) + 64), 1, 32)]))

                C.global[ramp((x.c*32), 1, 32)] = ((float32x32*)C.global[ramp((x.c*32), 1, 32)] + (broadcast((float32*)A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 3)], 32)*(float32x32*)packedB[ramp((((y.outer*32768) + (k.outer*128)) + 96), 1, 32)]))

              }

            }

            for (x.inner: int32, 0, 32) {

              for (y.inner: int32, 0, 32) {

                C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = (float32*)C.global[((x.inner*32) + y.inner)]

              }

            }

          }

        }

      }

    }

    Summary

    仅用18行代码应用上述简单的优化之后,生成的代码可以达到使用MKL的numpy性能的60%。请注意,网页上的输出反映的是非独占Docker容器上的运行时间,因此并不可靠。强烈建议亲自运行本教程,以观察TVM所获得的性能提升。

    https://tvm.apache.org/docs/tutorials/optimize/opt_gemm.html#sphx-glr-tutorials-optimize-opt-gemm-py

  • 相关阅读:
    栈和堆的区别【个人总结】
    理解堆与栈
    javacript属性
    Reapeater CommandName ,CommandArgument
    FormsAuthentication.HashPasswordForStoringInConfigFile(str1, str2);
    文件上传处理
    GetJson
    js内置对象
    Debug
    [转]关于一些SPFA的标程
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/14109537.html
Copyright © 2011-2022 走看看