LeetGPU习题07:Norm系列代码实现

2948 字
15 分钟
LeetGPU习题07:Norm系列代码实现
2026-06-12

Pytorch调用#

实际上PyTorch的调用接口是高度统一的,它们都继承自nn.Module,核心都是沿着某个维度计算统计量,之后标准化,可选仿射变换。

BN、LN、RMS的区别

区别在于沿着哪个轴做归一化,以及计算哪些统计量

理解了这个本质,剩下的就是查找参数,看维度。

BatchNorm#

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, ...)
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, ...)
torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, ...)

对于BatchNorm来说,主要有上述三个版本,区别在于它们对于输入的维度布局不同,我们来看一下基本用法:

def batchnorm_example():
x = torch.rand(100, 16, 784)
layer = nn.BatchNorm1d(16)
out = layer(x)
y = torch.rand(1, 16, 7, 7)
layer = nn.BatchNorm2d(16) # 传入通道数C
out = layer(y)

Pytorch要求显式传入特征通道数,以便可以在内部初始化可学习参数weight和bias。它们的形状都是(C,), 实际上BatchNorm会根据自身是否为training来自动选择后续行为,对于调用者是完全透明的。

学习了官方的调用之后,我们来看一下具体是如何实现的:

def batchnorm(x, running_mean, running_var, weight, bias, training, momentum=0.1, eps=1e-5):
"""
x.dim() 返回张量的维度数,例如对于一个形状为 (3, 4, 5) 的张量,x.dim() 将返回 3。
range(x.dim()) 生成一个从 0 到 x.dim() - 1 的序列,例如生成 [0, 1, 2]。
[d for d in range(x.dim()) if d != 1] 是列表推导式,它遍历上述生成的序列,并将不等于 1 的维度索引添加到列表中。
"""
dims = [d for d in range(x.dim()) if d != 1]
"""
[1: -1]是一个列表,第一个元素是1,第二个元素是-1
[1] * (x.dim() - 2) 是一个列表,包含 x.dim() - 2 个元素,每个元素都是1。
shape表示前两个维度是1和-1,后边的维度全都是1
"""
shape = [1, -1] + [1] * (x.dim() - 2)
# 训练模式
if training:
mean = x.mean(dim=dims, keepdim=True) # shape: (1, C, 1, 1)
var_biased = x.var(dim=dims, keepdim=True, correction=0) # 有偏估计,用于当前批次的归一化
var_unbiased = x.var(dim=dims, keepdim=True, correction=1) # 无偏估计,用于更新全局统计量
with torch.no_grad():
# running_mean 和running_var 形状是(C,), 用squeeze()挤掉多余的维度
running_mean.data = (1 - momentum) * running_mean.data + momentum * mean.squeeze()
running_var.data = (1 - momentum) * running_var.data + momentum * var_unbiased.squeeze()
var = var_biased
else:
mean = running_mean.view(shape)
var = running_var.view(shape)
# 归一化
x_norm = (x - mean) / torch.sqrt(var + eps)
# 仿射变换:y=γx+β
if weight is not None:
w = weight.view(shape)
b = bias.view(shape)
return x_norm * w + b
return x_norm

具体内容如上,关键点我都已经写出来了,大家自行查阅即可。

优化思维时时刻刻都要有,我还写了一个优化版本的,mean、var_unbiased可以同时计算出来:

def manual_batchnorm_v2(x, running_mean, running_var, weight, bias, training, momentum=0.1, eps=1e-5):
dims = [d for d in range(x.dim()) if d != 1]
shape = [1, -1] + [1] * (x.dim() - 2)
if training:
# 计算参与归约的元素总数
n = 1
for d in dims:
n *= x.shape[d]
# 一次性计算均值和方差(无偏方差)
var_unbiased, mean = torch.var_mean(x, dim=dims, keepdim=True, correction=1)
# 转为有偏方差用于归一化
var_biased = var_unbiased * ((n - 1) / n)
with torch.no_grad():
running_mean.data = (1 - momentum) * running_mean.data + momentum * mean.squeeze()
running_var.data = (1 - momentum) * running_var.data + momentum * var_unbiased.squeeze()
var = var_biased
else:
mean = running_mean.view(shape)
var = running_var.view(shape)
x_norm = (x - mean) / torch.sqrt(var + eps)
if weight is not None:
return x_norm * weight.view(shape) + bias.view(shape)
return x_norm

LayerNorm#

LayerNorm不依赖batch统计量,在NLP和Transformer是绝对主力。

torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, ...)
# shape可以是int或者tuple,表示输入最后几个维度的大小,归一化在这些维度进行
# eps,同BatchNorm
# elementwise_affine,是否学习逐元素的weight和bias

调用示例如下:

ln = nn.LayerNorm(512)
x = torch.randn(32, 10, 512)
out = ln.(x) # 形状不变,每个token的512维向量被归一化

我们手动实现的版本:

import torch
import torch.nn as nn
class layernorm(nn.Module):
def __init__(self, embed_size, eps=1e-5):
super().__init__()
self.gamma = nn.Parameter(torch.ones(embed_size))
self.beta = nn.Parameter(torch.zeros(embed_size))
self.eps = eps
def forward(self, x):
# correction=0表示有偏估计
var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
x_norm = (x - mean) / torch.sqrt(var + self.eps)
return x_norm * self.gamma + self.beta
if __name__ == "__main__":
x = torch.randn(2, 4, 3)
embed_size = x.size(-1)
official = nn.LayerNorm(embed_size, eps=1e-5)
my = layernorm(embed_size, eps=1e-5)
# 复制相同权重
my.gamma.data = official.weight.data.clone()
my.beta.data = official.bias.data.clone()
diff = (official(x) - my(x)).abs().max().item()
print(f"Max difference: {diff:.2e}")

RMSNorm#

RMSNorm是LayerNorm的一个简化变体,直接减去均值计算的步骤,只除以均方根,在最新的大模型里广泛使用。

torch.nn.RMSNorm(normalized_shape, eps=1e-06, elementwise_affine=True, ...)

使用方法完全同LayerNorm一致。

特性BatchNormLayerNormRMSNorm
归一化轴沿 batch 和空间维度 (每个通道独立)沿特征维度 (每个样本独立)沿特征维度 (每个样本独立)
依赖 batch
可学习参数weight + bias (通道数)weight + bias (归一化形状)weight
适用场景CNN、CVNLP、Transformer大模型
PyTorch 接口BatchNorm1d/2d/3dLayerNormRMSNorm

Triton实现#

下边我们核心先讲解LN和RMS的Triton版本,BN就忽略掉,大家可以自行学习。

归一化的视角#

无论是LN还是RMS,输入都是一个二维矩阵XRN×KX \in R^{N \times K},其中N是所有token的总行数,K是特征维度。

每一行的计算是完全独立的。 对于第n行,算子聚合该行内所有K个元素,算出指定的统计量,再用统计量去逐元素变换该行。

Tip

这是经典的Reduce+Broadcast模式

具体来说,LayerNorm需要两步聚合,分别计算均值和方差;RMSNorm只进行依次平方和聚类。

在GPU上,Reduce是代价最高的部分,Broadcast是element wise操作,因此很明显,归一化操作是带宽密集型

LayerNOrm#

我在这里参考了flash-attenion官方的实现: flash-attention

关于反向传播的代码

实际上,除了前向传播,在写绝大多数算子的时候,也要考虑反向传播,后续我会写一个完整的macroTorch,来完整的学习是如何进行前向传播与反向传播的。

关于自动调优,可以写一个函数来囊括进去常用的内容。 同时我在这里保存了mean和rstd,这样能够给未来的反向传播直接复用,避免重新进行规约计算,因此可以看到我显式的进行了两次保存。

并且还有一个设计上的巧思,即步长参数化,而不是假设内存严格连续,这可以让kernel实现更加复杂的张量布局。

import torch
import torch.nn as nn
import triton
import triton.language as tl
from triton.testing import do_bench
def get_autotune_configs():
warp_size = 32
max_threads_per_block = 1024
configs = []
for num_warps in [1,2, 4, 8, 16, 32]:
if num_warps * warp_size <= max_threads_per_block:
configs.append(triton.Config({}, num_warps=num_warps))
return configs
@triton.autotune(
configs=get_autotune_configs(),
key=["N"],
)
@triton.jit
def layer_norm_fwd_kernel(
X_ptr, Y_ptr, W_ptr, B_ptr, Mean_ptr, Rstd_ptr,
stride_x_row, stride_y_row, # 传入行步长即可灵活索引,Triton编程常见模式
N, eps, BLOCK_N: tl.constexpr
):
# 每个program处理1行
row_idx = tl.program_id(0)
X_row_ptr = X_ptr + row_idx * stride_x_row
Y_row_ptr = Y_ptr + row_idx * stride_y_row
cols = tl.arange(0, BLOCK_N)
mask = cols < N
# 加载该行的所有元素,在计算中提升为FP32
x = tl.load(X_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)
w = tl.load(W_ptr + cols, mask=mask).to(tl.float32)
b = tl.load(B_ptr + cols, mask=mask).to(tl.float32)
# 规约:均值 and 方差
mean = tl.sum(x, axis=0) / N
tl.store(Mean_ptr + row_idx, mean)
x_bar = tl.where(mask, x - mean, 0.0)
var = tl.sum(x_bar * x_bar, axis=0) / N
rstd = 1.0 / tl.sqrt(var + eps)
tl.store(Rstd_ptr + row_idx, rstd)
# 广播:归一化+仿射变换
y = (x - mean) * rstd * w + b
tl.store(Y_row_ptr + cols, y, mask=mask)
def layer_norm_fwd(x, weight, bias=None, eps=1e-5):
M, N = x.shape
y = torch.empty_like(x)
mean = torch.empty(M, device=x.device, dtype=torch.float32)
rstd = torch.empty(M, device=x.device, dtype=torch.float32)
if bias is None:
bias = torch.zeros(N, device=x.device, dtype=weight.dtype)
if weight is None:
weight = torch.ones(N, device=x.device, dtype=x.dtype)
BLOCK_N = triton.next_power_of_2(N)
layer_norm_fwd_kernel[(M,)](
x, y, weight, bias, mean, rstd,
x.stride(0), y.stride(0),
N, eps, BLOCK_N=BLOCK_N
)
return y, mean, rstd
def test_correctness(shapes=[(128, 256), (512, 1024)]):
for M, N in shapes:
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
weight = torch.randn(N, device='cuda', dtype=torch.float32)
bias = torch.randn(N, device='cuda', dtype=torch.float32)
eps = 1e-5
ln = nn.LayerNorm(N, eps=eps).to('cuda')
ln.weight.data = weight
ln.bias.data = bias
y_ref = ln(x)
y_tri, _, _ = layer_norm_fwd(x, weight, bias, eps)
max_diff = (y_tri - y_ref).abs().max().item()
print(f"Shape ({M}, {N}): max diff = {max_diff:.6e}")
bench_perf_report = None
try:
from triton.testing import Benchmark, perf_report
@perf_report(
Benchmark(
x_names=["N"],
x_vals=[256, 512, 1024, 2048, 4096, 8192],
line_arg="provider",
line_vals=["triton", "pytorch"],
line_names=["Triton", "PyTorch"],
styles=[("blue", "-"), ("red", "-")],
ylabel="Latency (ms)",
plot_name="LayerNorm Fwd Performance",
args={"M": 1024, "eps": 1e-5, "dtype": torch.float32},
)
)
def _bench_perf_report(M, N, eps, dtype, provider):
device = 'cuda'
x = torch.randn(M, N, device=device, dtype=dtype)
weight = torch.randn(N, device=device, dtype=dtype)
bias = torch.randn(N, device=device, dtype=dtype)
if provider == "triton":
def run():
return layer_norm_fwd(x, weight, bias, eps)
else:
ln = nn.LayerNorm(N, eps=eps).to(device)
ln.weight.data = weight
ln.bias.data = bias
def run():
return ln(x)
return do_bench(run, quantiles=[0.5, 0.2, 0.8])
bench_perf_report = _bench_perf_report
except ImportError:
print("当前 Triton 版本不支持 perf_report,跳过高阶绘图功能\n")
if __name__ == "__main__":
test_correctness()
if bench_perf_report is not None:
bench_perf_report.run(show_plots=True, print_data=True, save_path="./layer_norm_fwd_perf.png")

性能对比如下:

LN在Triton和Pytorch的性能表现
LN在Triton和Pytorch的性能表现

RMSNorm#

LayerNorm需要计算两个统计量:mean和std。 mean就得涉及到规约,而RMSNorm只计算一个:

rms = sqrt(E[x^2] + eps)
y = x / rms * γ

核心代码如下:

@triton.autotune(
configs=autotune_configs(),
key=["N"],
)
@triton.jit
def rms_norm_fwd_kernel(
X_ptr, Y_ptr, W_ptr,
stride_x_row, stride_y_row,
N, eps,
BLOCK_N: tl.constexpr,
):
row_idx = tl.program_id(0)
X_row_ptr = X_ptr + row_idx * stride_x_row
Y_row_ptr = Y_ptr + row_idx * stride_y_row
cols = tl.arange(0, BLOCK_N)
mask = cols < N
x = tl.load(X_row_ptr + cols, mask=mask, other=0.0).to(tl.float32)
w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# RMSNorm: rstd = rsqrt(mean(x^2) + eps)
x2 = x * x
mean_x2 = tl.sum(x2, axis=0) / N
rstd = tl.rsqrt(mean_x2 + eps)
# 广播:缩放+仿射变换
y = x * rstd * w
tl.store(Y_row_ptr + cols, y, mask=mask)
def rms_norm_fwd(x, weight, eps=1e-5):
M, N = x.shape
y = torch.empty_like(x)
if weight is None:
weight = torch.ones(N, device=x.device, dtype=x.dtype)
BLOCK_N = triton.next_power_of_2(N)
rms_norm_fwd_kernel[(M,)](
x, y, weight,
x.stride(0), y.stride(0),
N, eps,
BLOCK_N=BLOCK_N,
)
return y

RMS的Triton与Pytorch性能对比图
RMS的Triton与Pytorch性能对比图
我在特征维度从256到8192进行了测试,Triton版本始终优于Pytorch,并且自动调优保证了在不同尺寸下的均衡表现。

RMS在FP16下的优势更加明显,因为Triton对混合精度路径的控制更加精细,避免了不必要的类型转换开销。

====================================================================================================
RMSNorm Performance: dtype=torch.float16
====================================================================================================
M N Triton(ms) PyTorch (ms) my vs PT
----------------------------------------------------------------------------------------------------
128 256 0.003816 0.021189 5.55x
128 512 0.004489 0.021973 4.89x
128 1024 0.004901 0.023449 4.78x
128 2048 0.005038 0.026430 5.25x
128 4096 0.006675 0.032510 4.87x
128 8192 0.008356 0.038082 4.56x
512 256 0.004816 0.022823 4.74x
512 512 0.005085 0.025100 4.94x
512 1024 0.006658 0.029662 4.45x
512 2048 0.008573 0.038389 4.48x
512 4096 0.011734 0.058861 5.02x
512 8192 0.022541 0.093368 4.14x
1024 256 0.005728 0.025211 4.40x
1024 512 0.006578 0.029643 4.51x
1024 1024 0.009162 0.037848 4.13x
1024 2048 0.011643 0.057440 4.93x
1024 4096 0.020953 0.093154 4.45x
1024 8192 0.041657 0.170544 4.09x
2048 256 0.006578 0.029480 4.48x
2048 512 0.007969 0.037949 4.76x
2048 1024 0.011659 0.057081 4.90x
2048 2048 0.021078 0.092297 4.38x
2048 4096 0.041632 0.169116 4.06x
2048 8192 0.077019 0.681771 8.85x
4096 256 0.007757 0.037739 4.87x
4096 512 0.012584 0.057165 4.54x
4096 1024 0.020957 0.091967 4.39x
4096 2048 0.040349 0.168362 4.17x
4096 4096 0.077127 0.669127 8.68x
4096 8192 0.149073 1.469917 9.86x
====================================================================================================
RMSNorm Performance: dtype=torch.float32
====================================================================================================
M N Triton(ms) PyTorch (ms) my vs PT
----------------------------------------------------------------------------------------------------
128 256 0.004033 0.016027 3.97x
128 512 0.004791 0.016718 3.49x
128 1024 0.005044 0.018569 3.68x
128 2048 0.006306 0.021332 3.38x
128 4096 0.009204 0.026772 2.91x
128 8192 0.013090 0.030454 2.33x
512 256 0.005348 0.017874 3.34x
512 512 0.006673 0.019975 2.99x
512 1024 0.008565 0.023732 2.77x
512 2048 0.012434 0.031180 2.51x
512 4096 0.020615 0.048329 2.34x
512 8192 0.043017 0.077455 1.80x
1024 256 0.006282 0.019679 3.13x
1024 512 0.008333 0.023323 2.80x
1024 1024 0.012538 0.030153 2.40x
1024 2048 0.020061 0.046295 2.31x
1024 4096 0.043250 0.076708 1.77x
1024 8192 0.079363 0.145954 1.84x
2048 256 0.008264 0.023329 2.82x
2048 512 0.011833 0.030187 2.55x
2048 1024 0.023103 0.046089 1.99x
2048 2048 0.042146 0.075980 1.80x
2048 4096 0.078298 0.144876 1.85x
2048 8192 0.150213 0.479396 3.19x
4096 256 0.012040 0.030257 2.51x
4096 512 0.020640 0.045097 2.18x
4096 1024 0.040291 0.075542 1.87x
4096 2048 0.079232 0.144678 1.83x
4096 4096 0.150072 0.465174 3.10x
4096 8192 0.296848 1.028277 3.46x

FP32下加速比在2-4倍之间,FP16下可达4-9倍。

CUDA实现#

二者对于每一行N,都会遍历该行的所有元素,计算统计量(平方和、均值和方差等),并且使用这些统计量做归一化变换。

实际上,Norm的核心分为两个阶段:

阶段 1 — Reduce(归约):
每一行的 K 个元素 → 汇总成 1 个统计量(标量)
例如:RMSNorm 把一行 8192 个 float 聚合成一个平方和
LayerNorm 把一行 8192 个 float 先聚合成均值,再聚合成方差
阶段 2 — Broadcast(广播):
用这 1 个标量去变换该行的每一个元素
y[n][k] = f(x[n][k], stat[n])

用 C++ 伪代码表达就是:

for (int n = 0; n < N; n++) { // N 行,行间独立
float stat = 0;
for (int k = 0; k < K; k++) { // Reduce: K 个元素 → 1 个值
stat += compute(x[n][k]);
}
stat = finalize(stat); // 例如 rsqrt(stat/K + eps)
for (int k = 0; k < K; k++) { // Broadcast: 1 个值 → K 个元素
y[n][k] = transform(x[n][k], stat, gamma[k], beta[k]);
}
}
Important

所有归一化的算子,本质上都是Reduce + Broadcast

支持与分享

如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!

赞助
LeetGPU习题07:Norm系列代码实现
https://dlog.com.cn/posts/leetgpu07/norm/
作者
杜子源
发布于
2026-06-12
许可协议
CC BY-NC-SA 4.0
Profile Image of the Author
杜子源
都是风景,幸会
公告
请狠狠地打赏我,打赏一次,爆更一篇!!
音乐
封面

音乐

暂未播放

0:00 0:00
暂无歌词
分类
标签
站点统计
文章
28
分类
8
标签
11
总字数
68,959
运行时长
0
最后活动
0 天前

目录