LeetGPU习题07:Norm系列代码实现
Pytorch调用
实际上PyTorch的调用接口是高度统一的,它们都继承自nn.Module,核心都是沿着某个维度计算统计量,之后标准化,可选仿射变换。
区别在于沿着哪个轴做归一化,以及计算哪些统计量
理解了这个本质,剩下的就是查找参数,看维度。
BatchNorm
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, ...)torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, ...)torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, ...)对于BatchNorm来说,主要有上述三个版本,区别在于它们对于输入的维度布局不同,我们来看一下基本用法:
def batchnorm_example(): x = torch.rand(100, 16, 784) layer = nn.BatchNorm1d(16) out = layer(x)
y = torch.rand(1, 16, 7, 7) layer = nn.BatchNorm2d(16) # 传入通道数C out = layer(y)Pytorch要求显式传入特征通道数,以便可以在内部初始化可学习参数weight和bias。它们的形状都是(C,),
实际上BatchNorm会根据自身是否为training来自动选择后续行为,对于调用者是完全透明的。
学习了官方的调用之后,我们来看一下具体是如何实现的:
def batchnorm(x, running_mean, running_var, weight, bias, training, momentum=0.1, eps=1e-5):
""" x.dim() 返回张量的维度数,例如对于一个形状为 (3, 4, 5) 的张量,x.dim() 将返回 3。 range(x.dim()) 生成一个从 0 到 x.dim() - 1 的序列,例如生成 [0, 1, 2]。 [d for d in range(x.dim()) if d != 1] 是列表推导式,它遍历上述生成的序列,并将不等于 1 的维度索引添加到列表中。 """ dims = [d for d in range(x.dim()) if d != 1]
""" [1: -1]是一个列表,第一个元素是1,第二个元素是-1 [1] * (x.dim() - 2) 是一个列表,包含 x.dim() - 2 个元素,每个元素都是1。 shape表示前两个维度是1和-1,后边的维度全都是1 """ shape = [1, -1] + [1] * (x.dim() - 2)
# 训练模式 if training: mean = x.mean(dim=dims, keepdim=True) # shape: (1, C, 1, 1) var_biased = x.var(dim=dims, keepdim=True, correction=0) # 有偏估计,用于当前批次的归一化 var_unbiased = x.var(dim=dims, keepdim=True, correction=1) # 无偏估计,用于更新全局统计量
with torch.no_grad(): # running_mean 和running_var 形状是(C,), 用squeeze()挤掉多余的维度 running_mean.data = (1 - momentum) * running_mean.data + momentum * mean.squeeze() running_var.data = (1 - momentum) * running_var.data + momentum * var_unbiased.squeeze() var = var_biased else: mean = running_mean.view(shape) var = running_var.view(shape)
# 归一化 x_norm = (x - mean) / torch.sqrt(var + eps) # 仿射变换:y=γx+β if weight is not None: w = weight.view(shape) b = bias.view(shape) return x_norm * w + b return x_norm具体内容如上,关键点我都已经写出来了,大家自行查阅即可。
优化思维时时刻刻都要有,我还写了一个优化版本的,mean、var_unbiased可以同时计算出来:
def manual_batchnorm_v2(x, running_mean, running_var, weight, bias, training, momentum=0.1, eps=1e-5): dims = [d for d in range(x.dim()) if d != 1] shape = [1, -1] + [1] * (x.dim() - 2)
if training: # 计算参与归约的元素总数 n = 1 for d in dims: n *= x.shape[d]
# 一次性计算均值和方差(无偏方差) var_unbiased, mean = torch.var_mean(x, dim=dims, keepdim=True, correction=1) # 转为有偏方差用于归一化 var_biased = var_unbiased * ((n - 1) / n)
with torch.no_grad(): running_mean.data = (1 - momentum) * running_mean.data + momentum * mean.squeeze() running_var.data = (1 - momentum) * running_var.data + momentum * var_unbiased.squeeze() var = var_biased else: mean = running_mean.view(shape) var = running_var.view(shape)
x_norm = (x - mean) / torch.sqrt(var + eps) if weight is not None: return x_norm * weight.view(shape) + bias.view(shape) return x_normLayerNorm
LayerNorm不依赖batch统计量,在NLP和Transformer是绝对主力。
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, ...)# shape可以是int或者tuple,表示输入最后几个维度的大小,归一化在这些维度进行# eps,同BatchNorm# elementwise_affine,是否学习逐元素的weight和bias调用示例如下:
ln = nn.LayerNorm(512)x = torch.randn(32, 10, 512)out = ln.(x) # 形状不变,每个token的512维向量被归一化我们手动实现的版本:
import torchimport torch.nn as nn
class layernorm(nn.Module): def __init__(self, embed_size, eps=1e-5): super().__init__() self.gamma = nn.Parameter(torch.ones(embed_size)) self.beta = nn.Parameter(torch.zeros(embed_size)) self.eps = eps
def forward(self, x): # correction=0表示有偏估计 var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
x_norm = (x - mean) / torch.sqrt(var + self.eps) return x_norm * self.gamma + self.beta
if __name__ == "__main__": x = torch.randn(2, 4, 3) embed_size = x.size(-1) official = nn.LayerNorm(embed_size, eps=1e-5) my = layernorm(embed_size, eps=1e-5)
# 复制相同权重 my.gamma.data = official.weight.data.clone() my.beta.data = official.bias.data.clone()
diff = (official(x) - my(x)).abs().max().item() print(f"Max difference: {diff:.2e}")RMSNorm
RMSNorm是LayerNorm的一个简化变体,直接减去均值计算的步骤,只除以均方根,在最新的大模型里广泛使用。
torch.nn.RMSNorm(normalized_shape, eps=1e-06, elementwise_affine=True, ...)使用方法完全同LayerNorm一致。
| 特性 | BatchNorm | LayerNorm | RMSNorm |
|---|---|---|---|
| 归一化轴 | 沿 batch 和空间维度 (每个通道独立) | 沿特征维度 (每个样本独立) | 沿特征维度 (每个样本独立) |
| 依赖 batch | 是 | 否 | 否 |
| 可学习参数 | weight + bias (通道数) | weight + bias (归一化形状) | 仅 weight |
| 适用场景 | CNN、CV | NLP、Transformer | 大模型 |
| PyTorch 接口 | BatchNorm1d/2d/3d | LayerNorm | RMSNorm |
Triton实现
下边我们核心先讲解LN和RMS的Triton版本,BN就忽略掉,大家可以自行学习。
归一化的视角
无论是LN还是RMS,输入都是一个二维矩阵,其中N是所有token的总行数,K是特征维度。
每一行的计算是完全独立的。 对于第n行,算子聚合该行内所有K个元素,算出指定的统计量,再用统计量去逐元素变换该行。
这是经典的Reduce+Broadcast模式
具体来说,LayerNorm需要两步聚合,分别计算均值和方差;RMSNorm只进行依次平方和聚类。
在GPU上,Reduce是代价最高的部分,Broadcast是element wise操作,因此很明显,归一化操作是带宽密集型。
LayerNOrm
我在这里参考了flash-attenion官方的实现: flash-attention
实际上,除了前向传播,在写绝大多数算子的时候,也要考虑反向传播,后续我会写一个完整的macroTorch,来完整的学习是如何进行前向传播与反向传播的。
关于自动调优,可以写一个函数来囊括进去常用的内容。 同时我在这里保存了mean和rstd,这样能够给未来的反向传播直接复用,避免重新进行规约计算,因此可以看到我显式的进行了两次保存。
并且还有一个设计上的巧思,即步长参数化,而不是假设内存严格连续,这可以让kernel实现更加复杂的张量布局。
import torchimport torch.nn as nnimport tritonimport triton.language as tlfrom triton.testing import do_bench
def get_autotune_configs(): warp_size = 32 max_threads_per_block = 1024 configs = [] for num_warps in [1,2, 4, 8, 16, 32]: if num_warps * warp_size <= max_threads_per_block: configs.append(triton.Config({}, num_warps=num_warps)) return configs
@triton.autotune( configs=get_autotune_configs(), key=["N"],)@triton.jitdef layer_norm_fwd_kernel( X_ptr, Y_ptr, W_ptr, B_ptr, Mean_ptr, Rstd_ptr, stride_x_row, stride_y_row, # 传入行步长即可灵活索引,Triton编程常见模式 N, eps, BLOCK_N: tl.constexpr): # 每个program处理1行 row_idx = tl.program_id(0) X_row_ptr = X_ptr + row_idx * stride_x_row Y_row_ptr = Y_ptr + row_idx * stride_y_row
cols = tl.arange(0, BLOCK_N) mask = cols < N
# 加载该行的所有元素,在计算中提升为FP32 x = tl.load(X_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) w = tl.load(W_ptr + cols, mask=mask).to(tl.float32) b = tl.load(B_ptr + cols, mask=mask).to(tl.float32)
# 规约:均值 and 方差 mean = tl.sum(x, axis=0) / N tl.store(Mean_ptr + row_idx, mean)
x_bar = tl.where(mask, x - mean, 0.0) var = tl.sum(x_bar * x_bar, axis=0) / N rstd = 1.0 / tl.sqrt(var + eps) tl.store(Rstd_ptr + row_idx, rstd)
# 广播:归一化+仿射变换 y = (x - mean) * rstd * w + b tl.store(Y_row_ptr + cols, y, mask=mask)
def layer_norm_fwd(x, weight, bias=None, eps=1e-5): M, N = x.shape y = torch.empty_like(x) mean = torch.empty(M, device=x.device, dtype=torch.float32) rstd = torch.empty(M, device=x.device, dtype=torch.float32)
if bias is None: bias = torch.zeros(N, device=x.device, dtype=weight.dtype) if weight is None: weight = torch.ones(N, device=x.device, dtype=x.dtype)
BLOCK_N = triton.next_power_of_2(N) layer_norm_fwd_kernel[(M,)]( x, y, weight, bias, mean, rstd, x.stride(0), y.stride(0), N, eps, BLOCK_N=BLOCK_N ) return y, mean, rstd
def test_correctness(shapes=[(128, 256), (512, 1024)]): for M, N in shapes: x = torch.randn(M, N, device='cuda', dtype=torch.float32) weight = torch.randn(N, device='cuda', dtype=torch.float32) bias = torch.randn(N, device='cuda', dtype=torch.float32) eps = 1e-5
ln = nn.LayerNorm(N, eps=eps).to('cuda') ln.weight.data = weight ln.bias.data = bias y_ref = ln(x)
y_tri, _, _ = layer_norm_fwd(x, weight, bias, eps)
max_diff = (y_tri - y_ref).abs().max().item() print(f"Shape ({M}, {N}): max diff = {max_diff:.6e}")
bench_perf_report = Nonetry: from triton.testing import Benchmark, perf_report
@perf_report( Benchmark( x_names=["N"], x_vals=[256, 512, 1024, 2048, 4096, 8192], line_arg="provider", line_vals=["triton", "pytorch"], line_names=["Triton", "PyTorch"], styles=[("blue", "-"), ("red", "-")], ylabel="Latency (ms)", plot_name="LayerNorm Fwd Performance", args={"M": 1024, "eps": 1e-5, "dtype": torch.float32}, ) ) def _bench_perf_report(M, N, eps, dtype, provider): device = 'cuda' x = torch.randn(M, N, device=device, dtype=dtype) weight = torch.randn(N, device=device, dtype=dtype) bias = torch.randn(N, device=device, dtype=dtype)
if provider == "triton": def run(): return layer_norm_fwd(x, weight, bias, eps) else: ln = nn.LayerNorm(N, eps=eps).to(device) ln.weight.data = weight ln.bias.data = bias def run(): return ln(x) return do_bench(run, quantiles=[0.5, 0.2, 0.8])
bench_perf_report = _bench_perf_reportexcept ImportError: print("当前 Triton 版本不支持 perf_report,跳过高阶绘图功能\n")
if __name__ == "__main__": test_correctness() if bench_perf_report is not None: bench_perf_report.run(show_plots=True, print_data=True, save_path="./layer_norm_fwd_perf.png")性能对比如下:

RMSNorm
LayerNorm需要计算两个统计量:mean和std。 mean就得涉及到规约,而RMSNorm只计算一个:
rms = sqrt(E[x^2] + eps)y = x / rms * γ核心代码如下:
@triton.autotune( configs=autotune_configs(), key=["N"],)@triton.jitdef rms_norm_fwd_kernel( X_ptr, Y_ptr, W_ptr, stride_x_row, stride_y_row, N, eps, BLOCK_N: tl.constexpr,): row_idx = tl.program_id(0) X_row_ptr = X_ptr + row_idx * stride_x_row Y_row_ptr = Y_ptr + row_idx * stride_y_row
cols = tl.arange(0, BLOCK_N) mask = cols < N
x = tl.load(X_row_ptr + cols, mask=mask, other=0.0).to(tl.float32) w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# RMSNorm: rstd = rsqrt(mean(x^2) + eps) x2 = x * x mean_x2 = tl.sum(x2, axis=0) / N rstd = tl.rsqrt(mean_x2 + eps) # 广播:缩放+仿射变换 y = x * rstd * w tl.store(Y_row_ptr + cols, y, mask=mask)
def rms_norm_fwd(x, weight, eps=1e-5): M, N = x.shape y = torch.empty_like(x)
if weight is None: weight = torch.ones(N, device=x.device, dtype=x.dtype)
BLOCK_N = triton.next_power_of_2(N) rms_norm_fwd_kernel[(M,)]( x, y, weight, x.stride(0), y.stride(0), N, eps, BLOCK_N=BLOCK_N, ) return y
RMS在FP16下的优势更加明显,因为Triton对混合精度路径的控制更加精细,避免了不必要的类型转换开销。
====================================================================================================RMSNorm Performance: dtype=torch.float16==================================================================================================== M N Triton(ms) PyTorch (ms) my vs PT---------------------------------------------------------------------------------------------------- 128 256 0.003816 0.021189 5.55x 128 512 0.004489 0.021973 4.89x 128 1024 0.004901 0.023449 4.78x 128 2048 0.005038 0.026430 5.25x 128 4096 0.006675 0.032510 4.87x 128 8192 0.008356 0.038082 4.56x 512 256 0.004816 0.022823 4.74x 512 512 0.005085 0.025100 4.94x 512 1024 0.006658 0.029662 4.45x 512 2048 0.008573 0.038389 4.48x 512 4096 0.011734 0.058861 5.02x 512 8192 0.022541 0.093368 4.14x 1024 256 0.005728 0.025211 4.40x 1024 512 0.006578 0.029643 4.51x 1024 1024 0.009162 0.037848 4.13x 1024 2048 0.011643 0.057440 4.93x 1024 4096 0.020953 0.093154 4.45x 1024 8192 0.041657 0.170544 4.09x 2048 256 0.006578 0.029480 4.48x 2048 512 0.007969 0.037949 4.76x 2048 1024 0.011659 0.057081 4.90x 2048 2048 0.021078 0.092297 4.38x 2048 4096 0.041632 0.169116 4.06x 2048 8192 0.077019 0.681771 8.85x 4096 256 0.007757 0.037739 4.87x 4096 512 0.012584 0.057165 4.54x 4096 1024 0.020957 0.091967 4.39x 4096 2048 0.040349 0.168362 4.17x 4096 4096 0.077127 0.669127 8.68x 4096 8192 0.149073 1.469917 9.86x
====================================================================================================RMSNorm Performance: dtype=torch.float32==================================================================================================== M N Triton(ms) PyTorch (ms) my vs PT---------------------------------------------------------------------------------------------------- 128 256 0.004033 0.016027 3.97x 128 512 0.004791 0.016718 3.49x 128 1024 0.005044 0.018569 3.68x 128 2048 0.006306 0.021332 3.38x 128 4096 0.009204 0.026772 2.91x 128 8192 0.013090 0.030454 2.33x 512 256 0.005348 0.017874 3.34x 512 512 0.006673 0.019975 2.99x 512 1024 0.008565 0.023732 2.77x 512 2048 0.012434 0.031180 2.51x 512 4096 0.020615 0.048329 2.34x 512 8192 0.043017 0.077455 1.80x 1024 256 0.006282 0.019679 3.13x 1024 512 0.008333 0.023323 2.80x 1024 1024 0.012538 0.030153 2.40x 1024 2048 0.020061 0.046295 2.31x 1024 4096 0.043250 0.076708 1.77x 1024 8192 0.079363 0.145954 1.84x 2048 256 0.008264 0.023329 2.82x 2048 512 0.011833 0.030187 2.55x 2048 1024 0.023103 0.046089 1.99x 2048 2048 0.042146 0.075980 1.80x 2048 4096 0.078298 0.144876 1.85x 2048 8192 0.150213 0.479396 3.19x 4096 256 0.012040 0.030257 2.51x 4096 512 0.020640 0.045097 2.18x 4096 1024 0.040291 0.075542 1.87x 4096 2048 0.079232 0.144678 1.83x 4096 4096 0.150072 0.465174 3.10x 4096 8192 0.296848 1.028277 3.46xFP32下加速比在2-4倍之间,FP16下可达4-9倍。
CUDA实现
二者对于每一行N,都会遍历该行的所有元素,计算统计量(平方和、均值和方差等),并且使用这些统计量做归一化变换。
实际上,Norm的核心分为两个阶段:
阶段 1 — Reduce(归约): 每一行的 K 个元素 → 汇总成 1 个统计量(标量)
例如:RMSNorm 把一行 8192 个 float 聚合成一个平方和 LayerNorm 把一行 8192 个 float 先聚合成均值,再聚合成方差
阶段 2 — Broadcast(广播): 用这 1 个标量去变换该行的每一个元素
y[n][k] = f(x[n][k], stat[n])用 C++ 伪代码表达就是:
for (int n = 0; n < N; n++) { // N 行,行间独立 float stat = 0; for (int k = 0; k < K; k++) { // Reduce: K 个元素 → 1 个值 stat += compute(x[n][k]); } stat = finalize(stat); // 例如 rsqrt(stat/K + eps)
for (int k = 0; k < K; k++) { // Broadcast: 1 个值 → K 个元素 y[n][k] = transform(x[n][k], stat, gamma[k], beta[k]); }}所有归一化的算子,本质上都是Reduce + Broadcast
支持与分享
如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!