TVM转换教程:精选工具与实战指南

2026-06-11阅读 0热度 0
深度学习

TVM 已经更新到了 0.21.0 版本,这次的中文文档也同步对齐了最新的 API 和用法。不过这次我们不聊版本差异,回到核心:深入编译流程中的关键一环——原始张量函数的转换(Transformation)。

【TVM教程】转换

Apache TVM 是一套面向 CPU、GPU 以及各种机器学习加速芯片的深度学习编译框架,架构很深层,但用法可以很优雅。

之前一节已经展示过如何用 TensorIR 写出 mm_relu。实际场景中,同一个功能往往有多种实现路径,而不同的路径在性能上天差地别。先回顾一下上一节的实现:

import tvm
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class MyModule:
    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        Y = T.alloc_buffer((128, 128))
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

动手之前,先给原始实现测个速:

import numpy as np

a_np = np.random.uniform(size=(128, 128)).astype("float32")
b_np = np.random.uniform(size=(128, 128)).astype("float32")
c_np = a_np @ b_np

a_nd = tvm.runtime.tensor(a_np)
b_nd = tvm.runtime.tensor(b_np)
c_nd = tvm.runtime.tensor(np.zeros((128, 128), dtype="float32"))


def evaluate(mod: tvm.IRModule):
    lib = tvm.tir.build(mod, target="llvm")
    # 检查正确性
    lib(a_nd, b_nd, c_nd)
    np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5)
    # 评估性能
    f_timer = lib.time_evaluator("main", tvm.cpu())
    print(f_timer(a_nd, b_nd, c_nd))


evaluate(MyModule)

输出:

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
   2.7253       2.7253       2.7253       2.7253       0.0000

初始化计划

转换的第一步,创建一个 Schedule 辅助类,把 MyModule 喂进去:

sch = tvm.tir.Schedule(MyModule)

循环分块(Loop Tiling)

拿到 Y 块及其相关的循环引用:

block_Y = sch.get_block("Y")
i, j, k = sch.get_loops(block_Y)

接下来动刀子:把循环 j 切成两段,内层循环长度设为 8。注意,转换是分布执行的,要是手滑执行两遍,会因为 j 变量不存在而报错:

j0, j1 = sch.split(j, factors=[None, 8])

看看 sch.mod 里的变形结果:

sch.mod.show()

输出:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0, j_1, k in T.grid(128, 16, 8, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j_0 * 8 + j_1)
                vk = T.axis.reduce(128, k)
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(Y[vi, vj])
                with T.init():
                    Y[vi, vj] = T.float32(0.0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Y[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

切完后得到 j_0(范围 16)和 j_1(范围 8)。然后重排一下顺序:

sch.reorder(j0, k, j1)
sch.mod.show()
evaluate(sch.mod)

输出:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0, k, j_1 in T.grid(128, 16, 128, 8):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j_0 * 8 + j_1)
                vk = T.axis.reduce(128, k)
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(Y[vi, vj])
                with T.init():
                    Y[vi, vj] = T.float32(0.0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(Y[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
   0.8571       0.8571       0.8571       0.8571       0.0000

看到没?从 2.7ms 直接掉到 0.85ms,这就是循环变换的威力。

利用局部性

接下来再搞两个新花样,生成另一个变体。

先用原语 reverse_compute_at,把块 C 塞进块 Y 的某个内部循环:

block_C = sch.get_block("C")
sch.reverse_compute_at(block_C, j0)
sch.mod.show()

输出:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0 in T.grid(128, 16):
            for k, j_1 in T.grid(128, 8):
                with T.block("Y"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    with T.init():
                        Y[vi, vj] = T.float32(0.0)
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in range(8):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + ax0)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

重写归约操作

到目前为止,归约的初始化和更新步骤还是挤在同一个块体里的。这种写法做循环变换很方便——初始化和更新需要保持外层循环(如 ij)同步。但前面的变换已经跑完了,现在可以用 decompose_reduction 把 Y 的初始化和归约更新拆开:

sch.decompose_reduction(block_Y, k)
sch.mod.show()
evaluate(sch.mod)

输出:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0 in T.grid(128, 16):
            for j_1_init in range(8):
                with T.block("Y_init"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
                    T.reads()
                    T.writes(Y[vi, vj])
                    Y[vi, vj] = T.float32(0.0)
            for k, j_1 in T.grid(128, 8):
                with T.block("Y_update"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in range(8):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + ax0)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
   0.3377       0.3377       0.3377       0.3377       0.0000

再次提速,从 0.85ms 压到 0.34ms。拆开归约这一步很关键。

追踪转换

TensorIR 的调度是过程化的“语言”,每一步转换都是按顺序执行的。想看看刚才到底干了什么?可以用 sch.trace 打印调度历史:

sch.trace.show()

输出:

# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="Y", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l4, l3, l5)
  b6 = sch.get_block(name="C", func_name="main")
  sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
  b7 = sch.decompose_reduction(block=b0, loop=l3)

当然,也可以把 IRModule 和调度历史一起打出来:

sch.show()

输出:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        Y = T.alloc_buffer((128, 128))
        for i, j_0 in T.grid(128, 16):
            for j_1_init in range(8):
                with T.block("Y_init"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
                    T.reads()
                    T.writes(Y[vi, vj])
                    Y[vi, vj] = T.float32(0.0)
            for k, j_1 in T.grid(128, 8):
                with T.block("Y_update"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(Y[vi, vj], A[vi, vk], B[vk, vj])
                    T.writes(Y[vi, vj])
                    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
            for ax0 in range(8):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 8 + ax0)
                    T.reads(Y[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = T.max(Y[vi, vj], T.float32(0.0))

# from tvm import tir
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="Y", func_name="main")
  l1, l2, l3 = sch.get_loops(block=b0)
  l4, l5 = sch.split(loop=l2, factors=[None, 8], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l4, l3, l5)
  b6 = sch.get_block(name="C", func_name="main")
  sch.reverse_compute_at(block=b6, loop=l4, preserve_unit_loops=False, index=-1)
  b7 = sch.decompose_reduction(block=b0, loop=l3)
免责声明

本网站新闻资讯均来自公开渠道,力求准确但不保证绝对无误,内容观点仅代表作者本人,与本站无关。若涉及侵权,请联系我们处理。本站保留对声明的修改权,最终解释权归本站所有。

相关阅读

更多
欢迎回来 登录或注册后,可保存提示词和历史记录
登录后可同步收藏、历史记录和常用模板
注册即表示同意服务条款与隐私政策