重写后的HTML内容如下:
GPU编程初看常被视作黑盒技艺——满目皆是 *warps*、*shared memory*、*tensor cores*,以及 kernel 中繁复的索引计算。不必焦虑,本文从一个实际案例切入,助你逐步掌握 Triton:手写一个完整的 softmax kernel。
我们以官方 Triton 教程为蓝本,剖析代码背后的核心原理,并辅以手绘示意图。若你曾觉得 GPU 编程教程过于抽象,本文恰好适合作为起点。目标不仅是写出可运行的 kernel,更是真正理解现代 AI 负载在 GPU 上的执行路径。
最后,我们将该 kernel 部署到 RTX 5090 上,与 PyTorch 原生 softmax 进行基准对比。结果并非简单的“Triton 绝对胜出”——其中存在一个性能断崖,这一教训恰恰揭示了 GPU 编程中至关重要的权衡。
Softmax:简洁数学,潜伏的内存瓶颈

按行计算 softmax 的数学表达极为简明:每行是一个独立的 logit 向量,softmax 将其转换为概率分布。例如一个 2×3 矩阵,并非对六个数值统一计算 softmax,而是分别对第 0 行和第 1 行独立操作。
真正的挑战并非数学本身,而是 GPU 端如何高效执行:数据需要搬运几次、中间结果驻留在何处、GPU 的计算时间是否被内存等待所吞噬。
朴素的 PyTorch 实现会将 softmax 拆解为若干独立的张量操作:max、减法、指数、求和、除法。每一步都可能从全局内存读取数据,再将中间值写回。
而融合的 Triton kernel 颠覆了这一模式:一次性加载一整行,所有 softmax 步骤在数据滞留于片上时完成,最后一次性写回最终结果。
这里的“片外”指 GPU 全局内存(DRAM):容量大但延迟高;“片上”指 GPU 计算单元内部的内存(寄存器或共享内存/SRAM):速度极快但容量极小。从概念上看,一个 Triton 程序处理一行,但实际运行时,大量 Triton 程序并行执行。
一个简单的 Triton 模型
在深入 softmax kernel 之前,先搭建一个简易模型。
假设有一个长度为 3072 的向量 X,需要为每个元素减 1。CPU 的思路是顺序循环:
for i in range(3072):
X[i] = X[i] - 1
在 GPU 上则截然不同——GPU 会将向量切分成块,并行处理。Triton 中,一个 kernel 描述单个程序实例的行为。启动 kernel 时,会创建一个网格,其内包含多个并行运行的程序实例。
BLOCK_SIZE = 1024
每个程序实例处理 1024 个元素:
3072 / 1024 = 3 → 需要 3 个程序实例。
program 0 → elements 0-1023
program 1 → elements 1024-2047
program 2 → elements 2048-3071
每个程序实例获取自身的 `program_id`,据此定位数据切片,执行相同操作。Softmax kernel 同理,只是每个程序实例处理矩阵的一行,而非向量的一块。
逐行拆解 Triton Softmax Kernel
一个 Triton 程序实例一次处理一行。当启动的程序数少于行数时,每个程序以固定步长在矩阵中跳跃,处理多行。
@triton.jit
def softmax_kernel(output_ptr, input_ptr,
input_row_stride, output_row_stride,
n_rows, n_cols,
BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr,):
row_start = tl.program_id(0) # 当前程序实例 ID
row_step = tl.num_programs(0) # 轴 0 上的实例总数
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# ...
`tl.program_id(0)` 获取当前实例的 id。若启动了 4 个程序,program 0 从 row 0 开始,program 1 从 row 1 开始,依此类推,每个程序按 `row_step` 跳跃处理后续行。
`row_stride` 告知程序在内存中前进多少字节才能到达下一行的起始位置。一个常见误区是认为下一行总在 `n_cols` 个元素之后开始——这对连续紧凑的张量成立,但并非所有布局都如此。
# 指向当前行在内存中的起始位置
row_start_ptr = input_ptr + row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets

这里需要区分两个概念:`n_cols` 是逻辑列数,`input_row_stride` 是两行之间的物理内存间距。
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
mask 告知 Triton 仅加载实际列,虚假列用 `-inf` 填充,因为 `exp(-inf) = 0`,不影响 softmax 分母。
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
先减最大值是为了数值稳定——不改变 softmax 结果,但防止指数溢出。这些操作全部发生在同一个融合的 Triton 程序内——`row_minus_max`、`numerator`、`denominator` 不会作为中间张量写回全局内存。
启动 Kernel:Python 包装器
Triton kernel 描述了程序实例内部的行为,但实际问题还需 Python 代码回答:块多大?多少 warp?启动几个程序?
def softmax(x):
n_rows, n_cols = x.shape
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# ...
选择 2 的幂的 BLOCK_SIZE——契合 Triton 的块编程模型和归约操作。一行 3000 列?BLOCK_SIZE 就用 4096,多余的部分用 mask 屏蔽。
num_warps = 8
Warp 是一组同步执行的 GPU 线程,`num_warps = 8` 意味着每个 Triton 程序实例使用 8 个 warp。

num_stages = 4 if SIZE_SMEM > 200000 else 2
`num_stages` 与其他参数不同——它帮助同一程序内的循环迭代重叠,例如一轮加载、一轮计算、一轮写入同时进行。但更多阶段会消耗更多片上资源,未必更好。

y = torch.empty_like(x)
为输出分配与输入 shape、dtype、device 相同的张量。
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0),
n_rows, n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_stages=num_stages, num_warps=num_warps,
grid=(1,),)
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
先编译一次 kernel,查看一个程序实例消耗多少寄存器与共享内存。

GPU 流多处理器资源有限。每个 SM 拥有固定的寄存器与共享内存预算。一个程序占用过多,同一 SM 能同时运行的程序数量就会减少。
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)
占用率受限于最先耗尽的资源。采用持久化风格的 kernel:不为每行启动一个程序,而是启动足够多的程序占满 GPU,每个程序循环处理多行。
基准测试

中小行大小下,PyTorch 更快,这符合预期。但在 N ≈ 8700 附近,两侧均遭遇性能断崖。此后 Triton kernel 反超。
这并不意味着 Triton 总是更快——GPU 性能高度依赖张量形状、块大小、资源使用。注意 y 轴展示的是有效带宽,根据输入输出张量大小计算,而非内部内存事务。在 Triton 实现中,N 超过 8192 后,BLOCK_SIZE 跳至 16384,每个程序实例内部处理更大块,资源压力骤增,性能出现突变。
总结
Triton 提供了一种在近乎 Python 层面编写 GPU kernel 的途径。这个示例也揭示了一个事实:并非 Triton 始终快于 PyTorch——因为 PyTorch 已高度优化。但当你需要自定义算子或操作融合时,Triton 是一款极为趁手的工具。理解背后的资源模型与占用率,才是写出高效 kernel 的关键。