TVM教程:深入理解TensorIR抽象核心原理
TVM 0.21.0 版本已经更新,对应的中文文档也同步对齐。今天来聊聊 TensorIR——这个 TVM 中的张量程序抽象,算是标准的机器学习编译框架之一。它的核心目标,就是描述循环以及相关的硬件加速选项,比如线程、专用硬件指令和内存访问。
为了更直观地理解,咱们先看一个具体的计算序列作为例子:假设有两个 128×128 的矩阵 A 和 B,我们执行以下两步张量计算。
Y_{i, j} &= sum_k A_{i, k} times B_{k, j}
C_{i, j} &= mathbb{relu}(Y_{i, j}) = mathbb{max}(Y_{i, j}, 0)
这个操作其实很常见——神经网络中的线性层加 ReLU 激活函数就是这种模式。下面我们用 TensorIR 来表达它。不过,在正式接触 TensorIR 之前,先来段原生 Python 代码做个热身:
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
Y = np.empty((128, 128), dtype="float32")
for i in range(128):
for j in range(128):
for k in range(128):
if k == 0:
Y[i, j] = 0
Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
for i in range(128):
for j in range(128):
C[i, j] = max(Y[i, j], 0)
理解了 NumPy 版本的实现后,现在来看 TensorIR 的写法。下面的代码块展示了 mm_relu 的 TVMScript 实现——这是一种嵌入在 Python AST 中的领域特定语言。
@tvm.script.ir_module
class MyModule:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
Y = T.alloc_buffer((128, 128), dtype="float32")
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
接下来,逐点拆解一下这个 TensorIR 程序中的关键元素。
函数参数与缓冲区
函数参数和 NumPy 版本中的参数一一对应。
# TensorIR
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
...
# NumPy
def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
...
这里的 A、B、C 用的是 T.Buffer 类型,明确标注了形状 (128, 128) 和数据类型 float32。多出来的这些信息,正好帮助 MLC 编译器生成专注于特定形状和数据类型的优化代码。
中间结果的分配也类似:
# TensorIR
Y = T.alloc_buffer((128, 128), dtype="float32")
# NumPy
Y = np.empty((128, 128), dtype="float32")
循环迭代
循环结构上也有直接的对应关系。 T.grid 是 TensorIR 的语法糖,能一口气写出多层嵌套的迭代器:
# TensorIR 用 `T.grid`
for i, j, k in T.grid(128, 128, 128):
...
# TensorIR 用 `range`
for i in range(128):
for j in range(128):
for k in range(128):
...
# NumPy
for i in range(128):
for j in range(128):
for k in range(128):
...
计算块
最核心的区别出现在计算语句上:TensorIR 引入了一个叫 T.block 的构造。先看代码对比:
# TensorIR
with T.block("Y"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
# NumPy
vi, vj, vk = i, j, k
if vk == 0:
Y[vi, vj] = 0
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
一个块代表 TensorIR 中基本的计算单元。关键点在于,块携带的信息比 NumPy 代码要多——它明确了一组块轴 vi, vj, vk,以及围绕这些轴的计算约束。
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(128, j)
vk = T.axis.reduce(128, k)
这三行声明的格式可以抽象为:
[block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])
它们至少表达了三点:
- 指定了
vi、vj、vk(本例中绑定到i、j、k)。 - 声明了这些轴的原生范围(比如
T.axis.spatial(128, i)中的 128)。 - 标明了迭代器的属性——是空间轴还是归约轴。
块轴属性
再深入一点看块轴属性的含义。这些属性描述了轴与当前计算的关系。在这个块里,三个轴 vi、vj、vk 分别读取 A[vi, vk] 和 B[vk, vj],并写入 Y[vi, vj]。这里对 Y 做的是(归约)更新,我们暂且叫它写入,因为不需要从其他地方拿 Y 的值。
重点是:对于固定的 vi 和 vj,这个计算块产生的是 Y 中一个空间位置的点值(Y[vi, vj]),它独立于 Y 中其他位置(不同 vi、vj 的值)。所以 vi、vj 被称作空间轴——它们直接对应块写入缓冲区区域的起始位置。参与归约的轴(vk)自然就是归约轴。
为什么在块中有额外的信息
这额外信息(块轴的范围和属性)让块在执行迭代时完全自包含,不受外部 i, j, k 循环的干扰。换句话说,你可以独立检查每个块是否正确。
此外,块轴信息还提供了一种校验机制——帮助验证外部循环与计算是否匹配。例如下面这段代码就会报错,因为循环期望大小为 128 的迭代器,但实际只绑定了 127:
# 错误的程序,由于循环和块迭代不匹配
for i in range(127):
with T.block("C"):
vi = T.axis.spatial(128, i)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
这里出现错误是因为迭代器大小不匹配
...
块轴绑定的语法
如果每个块轴都直接映射到外部循环迭代器,可以用 T.axis.remap 在单行内完成声明:
# SSR 表示每个轴的属性分别为 "spatial", "spatial", "reduce"
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
这等价于:
vi = T.axis.spatial(range_of_i, i)
vj = T.axis.spatial(range_of_j, j)
vk = T.axis.reduce (range_of_k, k)
因此,我们可以用更简洁的方式重写上面的程序:
@tvm.script.ir_module
class MyModuleWithAxisRemapSugar:
@T.prim_func
def mm_relu(A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")):
Y = T.alloc_buffer((128, 128), dtype="float32")
for i, j, k in T.grid(128, 128, 128):
with T.block("Y"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
Y[vi, vj] = T.float32(0)
Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
这样,代码读起来就清爽多了。TensorIR 这种抽象方式,本质上是在传统循环嵌套的基础上,显式地注入了计算结构和硬件相关的语义——这正是机器学习编译想要自动优化代码的关键前提。
