Triton学习之路[1]:掌握模板代码
Triton、CUDA与Pytorch的核心差异
为什么我们还需要自己写算子?
在深度学习的模型开发中,Pytorch无疑是当下的绝对霸主。它提供了极其丰富的算子库,让我们能够用搭积木的方式快速构建复杂的神经网络。
然而,在大模型时代,标准的Pytorch有时已经无法满足我们对于极致性能的需求,当现有算子组合不够高效时,手写算子就成了我们的必经之路。
OpenAI推出的Triton正是为了降低CUDA开发的编程门槛而设计的,它在易用性和性能之间找到了一个很好的平衡点。
三种框架的编程范式对比
为了理解这Pytorch、CUDA和Triton三种框架的差异,我们需要仔细去考虑一下他们的对待数据与计算的方式。
Pytorch
Pytorch是一个极其上层的框架。在Pytorch中,最小的计算单位是Tensor。
例如,我们可以写出:C = A + B,我们完全不需要关心矩阵乘法在GPU内部的调度、内存管理等细节。
它专注于数学和功能逻辑,屏蔽了底层所有的硬件细节。
CUDA
写CUDA代码时,程序员需要把视角切换到单个线程上。 你需要手动管理Grid、Block,然后控制每一个Thread的行为,例如线程读取哪块内存?计算什么元素?何时同步等。
int idx = blockIdx.x * blockDim.x + threadIdx.x;if (idx < N) { C[i] = A[i] + B[i]}此外,你还需要手动管理共享内存、寄存器使用,才能榨干 GPU 的性能。
Triton
Triton 提供了一种折中方案:分块编程范式。 你不再需要管理单个线程,而是以 Tile为粒度来描述计算。
@triton.jitdef add_kernel(A_ptr, B_ptr, C_ptr, N, BLOCK_SIZE: tl.constexpr): # 获取当前Program的ID pid = tl.program_id(axis=0) # 获取Tile内的所有偏移 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < N
a = tl.load(A_ptr + offsets) b = tl.load(B_ptr + offsets)
c = a + b
tl.store(C_ptr+offsets, c)仔细观察上边的代码,看起来我们似乎像是在写标量计算,但是tl.arange生成了一个向量化的偏移数组,tl.load一次性加载了一整块数据,Triton编译器会自动将这块的计算映射到底层线程上。
对比维度
| 对比维度 | CUDA | Triton | Pytorch |
|---|---|---|---|
| 编程语言 | C++ | Python+装饰器标注 | Python |
| 线程管理 | 显式线程Grid、Block、Thread | 分块Program | 张量运算 |
| 共享内存 | 手动分配与管理 | 编译器自动分配管理 | 完全透明 |
| 上手难度 | 高 | 中 | 低 |
| 平台 | Nvidia | Nvidia、AMD等 | 大多数现代GPU |
CUDA像是专业餐厅里的猛火灶,上手难度高,需要牛逼的大厨才能炒出牛逼的菜。
Triton 家用小灶,上手难度适中,炒两个家常菜不成问题,并且也很可口。
Pytorch 预制菜料包。加热就能使用,但是口味固定。
Triton框架的完整结构
一个完整的Triton算子项目框架如下:

基础的Kernel
这是最核心基础的部分, 大部分Triton Kernel都遵循以下模板:
import torchimport tritonimport triton.language as tl
@triton.jitdef my_kernel( input_ptr, # 输入数据指针 output_ptr, # 输出数据指针 N, # 总元素数量 BLOCK_SIZE: tl.constexpr, # 编译期常量,每个 Block 处理的元素数): # 1. 获取当前 Block 的 ID pid = tl.program_id(axis=0)
# 2. 计算当前 Block 负责的偏移量范围 block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 3. 生成越界保护掩码 mask = offsets < N
# 4. 从全局内存加载数据 x = tl.load(input_ptr + offsets, mask=mask)
# 5. 执行计算(核心逻辑) y = your_computation(x)
# 6. 将结果写回全局内存 tl.store(output_ptr + offsets, y, mask=mask)Host端封装函数
Kernel并不能直接被用户调用,需要一个Python函数来进行封装:
def my_operator(x: torch.Tensor) -> torch.Tensor:
# 确保输入在 GPU 上 并且传入的内存连续 assert x.is_cuda and x.is_contiguous()
# 预分配输出内存 output = torch.empty_like(x) N = x.numel()
# 配置执行参数 BLOCK_SIZE = 256 # 可选的块大小 # triton.cdiv向上取整,计算需要多少个Block能够覆盖所有元素 grid = (triton.cdiv(N, BLOCK_SIZE),) # grid是一个元组,表示启动1D、2D或者3D的网格,这里使用(x,)确保它是一个元组
# 启动 Kernel my_kernel[grid](x, output, N, BLOCK_SIZE=BLOCK_SIZE)
return output自动调优
在实际项目中,BLOCK_SIZE等参数需要根据问题的规模来动态选择,因此Triton提供了@triton.autotune装饰器:
@triton.autotune( configs=[ triton.Config({'BLOCK_SIZE': 128}, num_warps=4), # 每个Block分配的Warp的数量,CUDA中学过,一个Warp中有32个线程 triton.Config({'BLOCK_SIZE': 256}, num_warps=8), triton.Config({'BLOCK_SIZE': 512}, num_warps=16), ], key=['N'],)
@triton.jitdef my_kernel_autotuned(input_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr): # Kernel 内容同上 ...基准测试以及正确性验证
那实现了基本功能之后,我们需要对其进行配套测试:
def test_operator(): # 生成测试数据 x = torch.randn(10000, device='cuda')
# 基准实现(通常用 PyTorch 原生算子) y_ref = torch.sigmoid(x)
# Triton 实现 y_triton = my_operator(x)
# 正确性验证 torch.testing.assert_close(y_triton, y_ref, rtol=1e-3, atol=1e-5)
# 性能对比(使用 triton.testing.do_bench) ms_ref = triton.testing.do_bench(lambda: torch.sigmoid(x)) ms_triton = triton.testing.do_bench(lambda: my_operator(x))
print(f"PyTorch: {ms_ref:.4f} ms, Triton: {ms_triton:.4f} ms")Pytorch算子注册
对于常用的算子,我们可以注册在Pytorch的torch.ops中。
import torch.library
# 定义算子库lib = torch.library.Library("my_ops", "DEF")lib.define("sigmoid_custom(Tensor x) -> Tensor")
# 注册 Triton 实现@torch.library.impl(lib, "sigmoid_custom", "CUDA")def sigmoid_custom_cuda(x): return my_operator(x)
# 使用时直接调用y = torch.ops.my_ops.sigmoid_custom(x)写Triton时需要考虑的关键问题
我们通常在写一个triton算子时,需要考虑几个问题:
数据类型是否支持?输入指针是否有效?数据规模是否合理?边界处理是否正确?
BLOCK_SIZE和num_warps如何调优?是否需要共享内存?是否需要多维度Grid?是否需要流水线预取?
算子是否需要集成?是否需要支持多种GPU架构?是否具有足够的可扩展性?
掌握了这个模板,你就拥有了编写绝大多数Triton算子的基础框架,后续学习可以更多的往里边填充具体的计算逻辑和优化策略即可。
支持与分享
如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!