TVM教程:深入理解TensorIR抽象核心原理

2026-06-11阅读 0热度 0
Tensor

TVM 0.21.0 版本已经更新,对应的中文文档也同步对齐。今天来聊聊 TensorIR——这个 TVM 中的张量程序抽象,算是标准的机器学习编译框架之一。它的核心目标,就是描述循环以及相关的硬件加速选项,比如线程、专用硬件指令和内存访问。

【TVM教程】理解 TensorIR 抽象

为了更直观地理解,咱们先看一个具体的计算序列作为例子:假设有两个 128×128 的矩阵 AB,我们执行以下两步张量计算。

Y_{i, j} &= sum_k A_{i, k} times B_{k, j} 
C_{i, j} &= mathbb{relu}(Y_{i, j}) = mathbb{max}(Y_{i, j}, 0)

这个操作其实很常见——神经网络中的线性层加 ReLU 激活函数就是这种模式。下面我们用 TensorIR 来表达它。不过,在正式接触 TensorIR 之前,先来段原生 Python 代码做个热身:

def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
    for i in range(128):
        for j in range(128):
            C[i, j] = max(Y[i, j], 0)

理解了 NumPy 版本的实现后,现在来看 TensorIR 的写法。下面的代码块展示了 mm_relu 的 TVMScript 实现——这是一种嵌入在 Python AST 中的领域特定语言。

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

接下来,逐点拆解一下这个 TensorIR 程序中的关键元素。

函数参数与缓冲区

函数参数和 NumPy 版本中的参数一一对应。

# TensorIR
def mm_relu(A: T.Buffer((128, 128), "float32"),
            B: T.Buffer((128, 128), "float32"),
            C: T.Buffer((128, 128), "float32")):
    ...
# NumPy
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    ...

这里的 ABC 用的是 T.Buffer 类型,明确标注了形状 (128, 128) 和数据类型 float32。多出来的这些信息,正好帮助 MLC 编译器生成专注于特定形状和数据类型的优化代码。

中间结果的分配也类似:

# TensorIR
Y = T.alloc_buffer((128, 128), dtype="float32")
# NumPy
Y = np.empty((128, 128), dtype="float32")

循环迭代

循环结构上也有直接的对应关系。 T.grid 是 TensorIR 的语法糖,能一口气写出多层嵌套的迭代器:

# TensorIR 用 `T.grid`
for i, j, k in T.grid(128, 128, 128):
    ...
# TensorIR 用 `range`
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...
# NumPy
for i in range(128):
    for j in range(128):
        for k in range(128):
            ...

计算块

最核心的区别出现在计算语句上:TensorIR 引入了一个叫 T.block 的构造。先看代码对比:

# TensorIR
with T.block("Y"):
    vi = T.axis.spatial(128, i)
    vj = T.axis.spatial(128, j)
    vk = T.axis.reduce(128, k)
    with T.init():
        Y[vi, vj] = T.float32(0)
    Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
# NumPy
vi, vj, vk = i, j, k
if vk == 0:
    Y[vi, vj] = 0
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

一个块代表 TensorIR 中基本的计算单元。关键点在于,块携带的信息比 NumPy 代码要多——它明确了一组块轴 vi, vj, vk,以及围绕这些轴的计算约束。

vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)

这三行声明的格式可以抽象为:

[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])

它们至少表达了三点:

  • 指定了 vivjvk(本例中绑定到 ijk)。
  • 声明了这些轴的原生范围(比如 T.axis.spatial(128, i) 中的 128)。
  • 标明了迭代器的属性——是空间轴还是归约轴。

块轴属性

再深入一点看块轴属性的含义。这些属性描述了轴与当前计算的关系。在这个块里,三个轴 vivjvk 分别读取 A[vi, vk]B[vk, vj],并写入 Y[vi, vj]。这里对 Y 做的是(归约)更新,我们暂且叫它写入,因为不需要从其他地方拿 Y 的值。

重点是:对于固定的 vivj,这个计算块产生的是 Y 中一个空间位置的点值(Y[vi, vj]),它独立于 Y 中其他位置(不同 vivj 的值)。所以 vivj 被称作空间轴——它们直接对应块写入缓冲区区域的起始位置。参与归约的轴(vk)自然就是归约轴。

为什么在块中有额外的信息

这额外信息(块轴的范围和属性)让块在执行迭代时完全自包含,不受外部 i, j, k 循环的干扰。换句话说,你可以独立检查每个块是否正确。

此外,块轴信息还提供了一种校验机制——帮助验证外部循环与计算是否匹配。例如下面这段代码就会报错,因为循环期望大小为 128 的迭代器,但实际只绑定了 127:

# 错误的程序,由于循环和块迭代不匹配
for i in range(127):
    with T.block("C"):
        vi = T.axis.spatial(128, i)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
        这里出现错误是因为迭代器大小不匹配
        ...

块轴绑定的语法

如果每个块轴都直接映射到外部循环迭代器,可以用 T.axis.remap 在单行内完成声明:

# SSR 表示每个轴的属性分别为 "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])

这等价于:

vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce (range_of_k, k)

因此,我们可以用更简洁的方式重写上面的程序:

@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
    @T.prim_func
    def mm_relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32"),
                C: T.Buffer((128, 128), "float32")):
        Y = T.alloc_buffer((128, 128), dtype="float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

这样,代码读起来就清爽多了。TensorIR 这种抽象方式,本质上是在传统循环嵌套的基础上,显式地注入了计算结构和硬件相关的语义——这正是机器学习编译想要自动优化代码的关键前提。

免责声明

本网站新闻资讯均来自公开渠道,力求准确但不保证绝对无误,内容观点仅代表作者本人,与本站无关。若涉及侵权,请联系我们处理。本站保留对声明的修改权,最终解释权归本站所有。

相关阅读

更多
欢迎回来 登录或注册后,可保存提示词和历史记录
登录后可同步收藏、历史记录和常用模板
注册即表示同意服务条款与隐私政策