Triton学习之路[1]:掌握模板代码

1552 字
8 分钟
Triton学习之路[1]:掌握模板代码
2026-04-12

Triton、CUDA与Pytorch的核心差异#

为什么我们还需要自己写算子?#

在深度学习的模型开发中,Pytorch无疑是当下的绝对霸主。它提供了极其丰富的算子库,让我们能够用搭积木的方式快速构建复杂的神经网络。

然而,在大模型时代,标准的Pytorch有时已经无法满足我们对于极致性能的需求,当现有算子组合不够高效时,手写算子就成了我们的必经之路。

OpenAI推出的Triton正是为了降低CUDA开发的编程门槛而设计的,它在易用性和性能之间找到了一个很好的平衡点。

三种框架的编程范式对比#

为了理解这Pytorch、CUDA和Triton三种框架的差异,我们需要仔细去考虑一下他们的对待数据与计算的方式。

Pytorch#

Pytorch是一个极其上层的框架。在Pytorch中,最小的计算单位是Tensor

例如,我们可以写出:C = A + B,我们完全不需要关心矩阵乘法在GPU内部的调度、内存管理等细节。 它专注于数学和功能逻辑,屏蔽了底层所有的硬件细节。

CUDA#

写CUDA代码时,程序员需要把视角切换到单个线程上。 你需要手动管理Grid、Block,然后控制每一个Thread的行为,例如线程读取哪块内存?计算什么元素?何时同步等。

int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) {
C[i] = A[i] + B[i]
}

此外,你还需要手动管理共享内存、寄存器使用,才能榨干 GPU 的性能。

Triton#

Triton 提供了一种折中方案:分块编程范式。 你不再需要管理单个线程,而是以 Tile为粒度来描述计算。

@triton.jit
def add_kernel(A_ptr, B_ptr, C_ptr, N, BLOCK_SIZE: tl.constexpr):
# 获取当前Program的ID
pid = tl.program_id(axis=0)
# 获取Tile内的所有偏移
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
a = tl.load(A_ptr + offsets)
b = tl.load(B_ptr + offsets)
c = a + b
tl.store(C_ptr+offsets, c)

仔细观察上边的代码,看起来我们似乎像是在写标量计算,但是tl.arange生成了一个向量化的偏移数组,tl.load一次性加载了一整块数据,Triton编译器会自动将这块的计算映射到底层线程上。

对比维度#

对比维度CUDATritonPytorch
编程语言C++Python+装饰器标注Python
线程管理显式线程Grid、Block、Thread分块Program张量运算
共享内存手动分配与管理编译器自动分配管理完全透明
上手难度
平台NvidiaNvidia、AMD等大多数现代GPU
打个比方

CUDA像是专业餐厅里的猛火灶,上手难度高,需要牛逼的大厨才能炒出牛逼的菜。

Triton 家用小灶,上手难度适中,炒两个家常菜不成问题,并且也很可口。

Pytorch 预制菜料包。加热就能使用,但是口味固定。

Triton框架的完整结构#

一个完整的Triton算子项目框架如下:

Triton算子项目完整框架
Triton算子项目完整框架

基础的Kernel#

这是最核心基础的部分, 大部分Triton Kernel都遵循以下模板:

import torch
import triton
import triton.language as tl
@triton.jit
def my_kernel(
input_ptr, # 输入数据指针
output_ptr, # 输出数据指针
N, # 总元素数量
BLOCK_SIZE: tl.constexpr, # 编译期常量,每个 Block 处理的元素数
):
# 1. 获取当前 Block 的 ID
pid = tl.program_id(axis=0)
# 2. 计算当前 Block 负责的偏移量范围
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 3. 生成越界保护掩码
mask = offsets < N
# 4. 从全局内存加载数据
x = tl.load(input_ptr + offsets, mask=mask)
# 5. 执行计算(核心逻辑)
y = your_computation(x)
# 6. 将结果写回全局内存
tl.store(output_ptr + offsets, y, mask=mask)

Host端封装函数#

Kernel并不能直接被用户调用,需要一个Python函数来进行封装:

def my_operator(x: torch.Tensor) -> torch.Tensor:
# 确保输入在 GPU 上 并且传入的内存连续
assert x.is_cuda and x.is_contiguous()
# 预分配输出内存
output = torch.empty_like(x)
N = x.numel()
# 配置执行参数
BLOCK_SIZE = 256 # 可选的块大小
# triton.cdiv向上取整,计算需要多少个Block能够覆盖所有元素
grid = (triton.cdiv(N, BLOCK_SIZE),) # grid是一个元组,表示启动1D、2D或者3D的网格,这里使用(x,)确保它是一个元组
# 启动 Kernel
my_kernel[grid](x, output, N, BLOCK_SIZE=BLOCK_SIZE)
return output

自动调优#

在实际项目中,BLOCK_SIZE等参数需要根据问题的规模来动态选择,因此Triton提供了@triton.autotune装饰器:

@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 128}, num_warps=4), # 每个Block分配的Warp的数量,CUDA中学过,一个Warp中有32个线程
triton.Config({'BLOCK_SIZE': 256}, num_warps=8),
triton.Config({'BLOCK_SIZE': 512}, num_warps=16),
],
key=['N'],
)
@triton.jit
def my_kernel_autotuned(input_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr):
# Kernel 内容同上
...

基准测试以及正确性验证#

那实现了基本功能之后,我们需要对其进行配套测试:

def test_operator():
# 生成测试数据
x = torch.randn(10000, device='cuda')
# 基准实现(通常用 PyTorch 原生算子)
y_ref = torch.sigmoid(x)
# Triton 实现
y_triton = my_operator(x)
# 正确性验证
torch.testing.assert_close(y_triton, y_ref, rtol=1e-3, atol=1e-5)
# 性能对比(使用 triton.testing.do_bench)
ms_ref = triton.testing.do_bench(lambda: torch.sigmoid(x))
ms_triton = triton.testing.do_bench(lambda: my_operator(x))
print(f"PyTorch: {ms_ref:.4f} ms, Triton: {ms_triton:.4f} ms")

Pytorch算子注册#

对于常用的算子,我们可以注册在Pytorch的torch.ops中。

import torch.library
# 定义算子库
lib = torch.library.Library("my_ops", "DEF")
lib.define("sigmoid_custom(Tensor x) -> Tensor")
# 注册 Triton 实现
@torch.library.impl(lib, "sigmoid_custom", "CUDA")
def sigmoid_custom_cuda(x):
return my_operator(x)
# 使用时直接调用
y = torch.ops.my_ops.sigmoid_custom(x)

写Triton时需要考虑的关键问题#

我们通常在写一个triton算子时,需要考虑几个问题:

正确性

数据类型是否支持?输入指针是否有效?数据规模是否合理?边界处理是否正确?

性能

BLOCK_SIZEnum_warps如何调优?是否需要共享内存?是否需要多维度Grid?是否需要流水线预取?

可维护性

算子是否需要集成?是否需要支持多种GPU架构?是否具有足够的可扩展性?

掌握了这个模板,你就拥有了编写绝大多数Triton算子的基础框架,后续学习可以更多的往里边填充具体的计算逻辑和优化策略即可。

支持与分享

如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!

赞助
Triton学习之路[1]:掌握模板代码
https://dlog.com.cn/posts/triton01/basic/
作者
杜子源
发布于
2026-04-12
许可协议
CC BY-NC-SA 4.0
Profile Image of the Author
杜子源
都是风景,幸会
公告
如果需要原图,请私信联系我。
音乐
封面

音乐

暂未播放

0:00 0:00
暂无歌词
分类
标签
站点统计
文章
10
分类
6
标签
7
总字数
16,881
运行时长
0
最后活动
0 天前

目录