Triton ops教程:高效编程技巧与实战案例
Triton 是为并行计算量身打造的语言与编译器,核心目标非常明确:在 Python 开发环境下,让你高效编写自定义深度学习算子,并在现代 GPU 上实现接近理论峰值的吞吐量。说白了,就是用 Python 的简洁性,写出逼近硬件极限性能的代码。
接下来,我们逐一拆解 Triton 中最常用的几个归约函数。掌握这些,基本能搞定大部分极值查找与索引相关的计算需求。
triton_language.argmax
triton.language.argmax(input, axis, tie_break_left=True, keep_dims=False)
该函数负责找出张量 input 在指定维度 axis 上所有元素中最大值的索引位置。这是一个标准的归约操作,最终返回索引值。
参数详解:
- input (Tensor) - 输入的高维张量数据。
- axis (int) - 执行归约的维度编号。
- keep_dims (bool) - 设为 True 时保留被归约的维度,使其长度为 1,输出形状与输入维度数一致。
- tie_break_left (bool) - 当多个最大值相等时,决定返回哪个索引。默认 True 表示取最左侧的非 NaN 值索引。
同时支持作为张量对象的方法调用,例如 x.argmax(...),写法更直观。
triton_language.argmin
triton.language.argmin(input, axis, tie_break_left=True, keep_dims=False)
逻辑与 argmax 完全对称,但查找的是最小值对应的索引。
参数详解:
- input (Tensor) - 输入张量。
- axis (int) - 归约维度。
- keep_dims (bool) - 是否保留归约维度。
- tie_break_left (bool) - 最小值出现平局时,默认返回最左侧的索引。
同样可用 x.argmin(...) 形式调用。
triton_language.max
triton.language.max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False)
与 argmax 不同,max 直接返回最大值本身,但还额外支持一并返回对应索引。
参数详解:
- input (Tensor) - 输入张量。
- axis (int) - 归约维度。默认为 None,表示对所有元素求最大值。
- keep_dims (bool) - 是否保留归约维度。
- return_indices (bool) - 设为 True 时,同时返回最大值对应的索引。
- return_indices_tie_break_left (bool) - 最大值平局且需要返回索引时,默认取最左侧的索引。
成员函数调用形式:x.max(...)。
triton_language.min
triton.language.min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False)
逻辑与 max 一致,但返回的是最小值。
参数详解:
- input (Tensor) - 输入张量。
- axis (int) - 归约维度。
- keep_dims (bool) - 是否保留归约维度。
- return_indices (bool) - 是否同时返回最小值索引。
- return_indices_tie_break_left (bool) - 最小值平局时,默认返回最左侧的索引。
同样支持 x.min(...) 调用。
triton_language.reduce
triton.language.reduce(input, axis, combine_fn, keep_dims=False)
这是一个通用归约操作,允许你通过自定义组合函数 combine_fn 决定如何合并张量元素。
参数详解:
- input (Tensor | tuple of Tensors) - 输入张量或张量元组。
- axis (int | None) - 归约维度。设为 None 时归约所有维度。
- combine_fn (Callable) - 接收两个标量张量并返回组合结果的函数。此函数必须用
@triton.jit装饰。 - keep_dims (bool) - 是否保留归约维度。
支持成员函数调用:x.reduce(...)。
triton_language.sum
triton.language.sum(input, axis=None, keep_dims=False)
最基础的求和操作,计算张量 input 在指定 axis 上所有元素的总和。
参数详解:
- input (Tensor) - 输入张量。
- axis (int) - 归约维度。
- keep_dims (bool) - 是否保留归约维度。
也可写作 x.sum(...)。
triton_language.xor_sum
triton.language.xor_sum(input, axis=None, keep_dims=False)
该函数计算“异或和”。沿指定 axis 对张量 input 中的元素逐个执行按位异或(XOR)并返回结果。
参数详解:
- input (Tensor) - 输入张量。
- axis (int) - 归约维度。
- keep_dims (bool) - 是否保留归约维度。
同样支持成员函数调用:x.xor_sum(...)。
