LeetGPU习题05:Softmax优化详解

7481 字
37 分钟
LeetGPU习题05:Softmax优化详解
2026-05-11

什么是Softmax?#

1. 数学定义#

Softmax数学定义
Softmax数学定义
给定一个实数向量 z=[z1,z2,,zK]\mathbf{z} = [z_1, z_2, \dots, z_K],Softmax 函数会将其转换为一个概率分布 p=[p1,p2,,pK]\mathbf{p} = [p_1, p_2, \dots, p_K],其中:

pi=ezij=1Kezjp_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}

这里 ee 是自然对数的底数(约等于 2.718)。这个公式保证了:

  • pi(0,1)p_i \in (0, 1)(每个值都在 0 到 1 之间)
  • i=1Kpi=1\sum_{i=1}^{K} p_i = 1(所有值加起来等于 1)

这两个性质正好满足概率分布的数学要求,因此 softmax 的输出可以直接解释为多项分布的概率

和Softmax对应的是argmax,argmax就是我们通常理解的max。

import torch
a = torch.randn(10)
b = torch.max(a)
print(a)
print(b.item())

但是argmax在求导方面存在诸多问题,例如不可导等等。

2. 历史渊源#

在现在的Transformer架构中,softmax无处不在,但是它最早是哪里来的呢?

1868年,玻尔兹曼在其奠基性统计力学有关玻尔兹曼分布的论文中提出了Softmax函数。 1902年,Gibbs在《统计力学基础原理》中对Softmax进行了形式化与推广。

实际上,玻尔兹曼分布定义了位于不同能量级的粒子数量的概率分布,该分布期望:

在能量与粒子数不变的前提下,能量级状态数量最大化。

在之后的逻辑回归中,研究者需要对具有多个无序类别的结果进行建模,因此它们将线性预测器的输出自然地对数-几率比参数化,由此推导出的概率公式就是Softmax。

在1989年,John S.bridle在论文中首次把这种网络输出指数化再归一化的层命名为Softmax,以区别于选择最大值的argmax

3. 为什么是指数函数?#

使用指数函数主要是两个原因:

  • 与交叉熵结合。 当我们用 Softmax 的输出作为预测概率,然后搭配 交叉熵损失 函数时,损失函数对某个未归一化输出 zkz_k 的梯度会变得极其简洁:
Lzk=pkyk\frac{\partial L}{\partial z_k} = p_k - y_k

(其中 yky_k 是真实标签的 one-hot 编码)这意味着反向传播的梯度信号就是“预测值”与“真实值”的误差,这种线性减法形式使得训练极为稳定,避免了梯度消失或爆炸。

  • 广义线性模型的规范链接。 在指数分布族的框架下,所有的概率分布(正态分布、泊松分布)都可以归纳为指数分布簇。 如果我们面对一个多分类问题,数学推导证明链接原始线性输出与最终概率分布的最优且唯一的规范链接函数,恰好就是Softmax。 这意味着使用 Softmax 和交叉熵时,我们实际上是在正确指定的统计模型下进行最大似然估计。

4. 继续深挖Softmax#

在神经网络分类层中,最后的全连接层输出通常被称为 Logits(未归一化的对数概率)。Logits 的取值范围是整个实数域 (,)(-\infty, \infty),而概率必须位于 [0, 1] 且和为 1。Softmax 在数学上实现了一个光滑的、满射的映射:

RKΔK1\mathbb{R}^K \mapsto \Delta^{K-1}

它将无限延伸的实数空间“压缩”成一个 (K-1) 维的概率单纯形。

注意 Softmax 的分母 log(ezj)\log\left(\sum e^{z_j}\right) 被称为 Log-Sum-Exp 函数,它是 max 函数的一个光滑近似:

max(z)logezj\max(\mathbf{z}) \approx \log\sum e^{z_j}

而 Softmax 恰好就是 Log-Sum-Exp 函数对每个分量 ziz_i 的偏导数。这揭示了 Softmax 的另一个数学身份:它是 max\max 算子的光滑梯度。

当存在一个输入 zjz_j 显著大于其他元素时,ezje^{z_j} 会在分母中占绝对主导地位,对应的 pjp_j 会无限趋近于 1,而其他 pip_i 趋近于 0。Softmax 就退化成一个近乎 one-hot 的分布,表现得像 argmax。反之,当输入值差异很小时,它会给出一个接近均分的概率分布。通过引入 温度参数 TT

pi=ezi/Tjezj/Tp_i = \frac{e^{z_i / T}}{\sum_{j} e^{z_j / T}}

我们可以控制输出的“软硬程度”:

  • T0T \to 0:趋近于 argmax,分布极尖锐。
  • TT \to \infty:趋近于均匀分布,极为平滑。

5. 为什么Transformer中要用Softmax#

在 Transformer 架构中,Softmax更是自注意力机制的核心组件。

很多市面上的八股都会告诉大家Softmax的若干作用,它的公式如下:

Attention(Q,K,V)=softmax ⁣(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

注意力机制本质上是信息检索的连续化。 设想一个包含键值对 {(kj,vj)}\{(k_j, v_j)\} 的数据库,给定查询 qq 后,最直接的做法是找到相似度最高的键,并输出其对应的值。

这等价于对相似度分数执行 argmax 后做点积。但是argmax 分段常数并且不可导,阻碍了基于梯度的端到端训练。

Softmax 正是 argmax 的一种光滑化、处处可导的近似。 它以连续概率的形式分配注意力。 相似度越高的键,获得指数级更大的权重,但所有键都会保留非零梯度通路。 这种“软查找”既保留了选择性聚合的能力,又让误差信号能无阻碍地反向传播到查询和键的表示学习中。

原始的注意力分数矩阵 QKQK^\top 中,元素可以是任意实数,数值范围不受约束。 如果直接用这些裸分去加权 VV 矩阵,随着序列长度增加或网络加深,输出向量的范数会不受控制地增长或衰减,导致训练不稳定。

Softmax 通过分母 jesij\sum_j e^{s_{ij}} 将所有分数映射到 (0,1) 之间,且每行的和为 1。于是注意力输出退化为 VV 中各行向量的一种凸组合,即加权平均值。

由于凸组合的范数严格被输入范数的最大值所限制,无论序列有多长,注意力层的输出量级始终保持在稳定区间。这是在深层架构中堆叠数十层注意力而无需额外范数约束的数学基础。

One more thing!

既然只需要加权平均,那更简单的 L1L_1 归一化(L1范数)为何不被采用?

原因在于长序列场景下的信噪分离需求。 线性归一化会将注意力分散到大量低相关度的 token 上,每个 token 都分得微小但不可忽略的权重。这种“平均主义”会稀释核心信息的表征,对推理和长程依赖非常不利。

指数函数 exe^x 作为凸的、严格正且增长迅速的激活函数,天然具备对比度放大的能力。 当分数间存在微小差异时,指数运算会显著拉大权重差距. 高相似度的 token 获得压倒性权重,低相似的 token 权重急剧趋近于零。 效果上,Softmax 自动实现了自适应稀疏性。 模型无需预先指定哪几个 token 应该被“看见”,而是通过指数竞争机制动态地让少数关键元素主导输出,同时屏蔽其余噪声。

One more more thing

为什么要引入缩放因子 1/dk1/\sqrt{d_k}

我们来看这个例子:

想象一个句子有三个token:[I love you]。 当前哟啊计算出第一个token的注意力,模型算出了它与三个key的点积分数:[2, 5, 10]

如果我们直接送进Softmax,输出则是[0.0000015%,0.00003%,99.99997%]

这几乎就是一个one-hot向量。第三个token完全拿走了所有的注意力,前两个token的梯度则几乎为0。 这会导致输出对输入的任何微小变化都毫无反应。

那我们现在加入缩放因子,假设dk=64d_k = 64,则缩放后的分数变为:[0.25, 0.625, 2.5]

这个时候输出就变成了:[7.5%, 10.9%, 81.6%] 虽然仍然是第三个值最大,但是剩下两个仍然保留了可观的份额,输出不再是one-hot。

One more more more thing?

为什么恰好是这个缩放因子?

实际上,对于dkd_k维度的随机向量,假设每个分量独立且方差为1,点积Q×KQ \times K的方差恰好等于dkd_k。因此点积标准差恰好是这个,也就是将点积的方差重新固定为 1

  • dk\sqrt{d_k} 缩放让分数分布恢复成“常温”状态,恰好落在 Softmax 既能拉开差距、又不至饱和的区间。
  • 如果除以更小的数(如 0.5dk0.5\sqrt{d_k}),分布仍偏宽,模型会重新倾向 one-hot;
  • 如果除以更大的数(如 2dk2\sqrt{d_k}),分布收得过窄,注意力会变得过于均匀,变成“什么都注意,什么都记不住”。

Pytorch版本实现#

1. naive版本#

好了,我们按照数学公式,直接来写吧:

def naive_softmax(x, dim=-1):
exp_x = torch.exp(x)
return exp_x / exp_x.sum(dim=dim, keepdim=True)

OK, 带入[1, 100, 100000]试一下,你试试呢~

(mycuda) 012 python psoftmax.py
tensor([0., nan, nan])
(mycuda) 012

这段代码看起来没什么问题,但是只要 x 中出现较大的元素exp就会直接溢出为 inf,导致整个 Softmax 输出变成一堆 nan。

2. Safe Softmax#

为了解决这个问题,论文中提出了它的解决方案:

Safe Softmax
Safe Softmax

我来给大家翻译一下。

  1. 找到这个向量/矩阵中要求行的最大值M。
  2. 将所有含有exe^x的式子变换为exMe^{x-M}
  3. 带入原式子

完整的代码如下,其中有注释,不再赘述:

import torch
def softmax(x, dim=-1):
"""
softmax(z)_i = exp(z_i) / sum_j exp(z_j)
"""
# 第 1 步:求最大值
# 沿目标维度 dim 求最大值,keepdim=True 保持原维度数,
# 便于后续广播减法。x_max 的形状会在 dim 维度变成 1,其他维不变。
x_max = x.max(dim=dim, keepdim=True)[0] # [0] 取出最大值
# 第 2 步:平移输入数据,减去最大值
# 数学上等价于将分子分母同除以 exp(x_max),防止指数爆炸。
# 因为 x_max 是沿 dim 维的最大值,x - x_max 的所有元素 ≤ 0,
safe_x = x - x_max # 广播机制
# 第 3 步:指数计算
# 对平移后的张量逐元素求 exp。
exp_x = torch.exp(safe_x)
# 第 4 步:求和
# 沿 dim 维度对 exp_x 求和,同样 keepdim=True 保持维度以便广播。
sum_exp = exp_x.sum(dim=dim, keepdim=True)
# 第 5 步:归一化得到最终概率
output = exp_x / sum_exp
return output
def naive_softmax(x, dim=-1):
exp_x = torch.exp(x)
return exp_x / exp_x.sum(dim=dim, keepdim=True)
if __name__ == "__main__":
# 一维向量
x1 = torch.tensor([1.0, 2.0, 3.0])
p1 = softmax(x1)
print("一维 softmax:\n",p1)
print(f"Torch Softmax:\n{torch.softmax(x1, dim=-1)}")
print("求和 =", p1.sum().item()) # 应为 1
# 二维矩阵
x2 = torch.tensor([[1., 2., 3.],
[4., 5., 6.]])
p2 = softmax(x2, dim=-1) # 对每一行做 softmax
print("\n二维 softmax:\n",p2)
print(f"Torch Softmax:\n{torch.softmax(x2, dim=-1)}")
print("每行求和 =", p2.sum(dim=-1)) # 结果 [1, 1]

Triton版本实现#

1. Safe Softmax版本实现#

我们已经学习过了element wise和reduce,大家可以仔细分析一下,softmax的运算都分为哪些呢?

许多官方库的实现都是采用Safe Softmax
许多官方库的实现都是采用Safe Softmax

从论文中,我们将其进行分解为Element wise和reduce试一下?

回顾数值稳定的 Safe Softmax 公式:

yi=eximjexjm,m=maxkxky_i = \frac{e^{x_i - m}}{\sum_{j} e^{x_j - m}}, \quad m = \max_k x_k

它由以下 5个子运算 组成,可以按模式归为两类:

步骤运算类型说明
① 求最大值 m=maxxm = \max xreduction跨元素归约需要在整行/整向量上进行比较归约,产生一个标量。
② 平移 xi=ximx'_i = x_i - melement-wise逐元素每个元素独立减去同一个标量,完全并行。
③ 指数 exie^{x'_i}element-wise逐元素每个元素独立计算指数函数。
④ 求和 d=exid = \sum e^{x'_i}reduction跨元素归约将指数结果累加为一个标量。
⑤ 归一化 yi=exi/dy_i = e^{x'_i} / delement-wise逐元素每个指数除以上述标量,得到最终概率。

论文指出,几乎所有深度学习框架使用的 Safe Softmax 都执行 三次遍历

  1. 第一遍求最大值 mm
  2. 第二遍求指数和 dd
  3. 第三遍计算输出 yiy_i

我们用triton来实现一下

@triton.jit
def kernel_softmax_fuse(
x_ptr, x_row_stride,
y_ptr, y_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
row_idx = tl.program_id(0)
x_ptr += row_idx * x_row_stride
y_ptr += row_idx * y_row_stride
idx = tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
x = tl.exp(x - tl.max(x))
eps = float(1e-9)
x /= tl.maximum(tl.sum(x), eps)
tl.store(y_ptr + idx, x, mask=idx < n_cols)
def triton_softmax_dim1_fuse(x):
n_rows, n_cols = x.shape
y = torch.empty_like(x)
kernel_softmax_fuse[[n_rows]](
x, x.stride(0),
y, y.stride(0),
n_cols,
BLOCK_SIZE=triton.next_power_of_2(n_cols),
num_warps=32,
)
return y

和原版不同的是,它只做 一次加载,然后直接在片上完成:

  • tl.max(x) → 求最大值
  • x - tl.max(x) → 平移
  • tl.exp(...) → 指数
  • tl.sum(x) → 求和(归约)
  • tl.maximum(..., eps) → 防除零
  • 除法和写回

对于参数来说,

@triton.jit
def kernel_softmax_fuse(
x_ptr, # 输入张量的基地址
x_row_stride, # 输入行步幅
y_ptr, # 输出张量的基地址
y_row_stride, # 输出行步幅
n_cols, # 列数(每行的元素个数)
BLOCK_SIZE: tl.constexpr, # 编译时常量,决定一次加载的元素数
):

Triton 是显式内存编程模型,kernel 接收到的是原始指针(即张量首元素的地址),而不是 tensor 对象。我们需要手动计算偏移来访问特定行的数据。

stride 是连续两行起点之间的元素个数(或者说偏移量)。在 PyTorch 中,x.stride(0) 就给出了这个值。(我们点击stride就可以查看函数定义)

def stride(self, dim: None = None) -> tuple[_int, ...]:
r"""
stride(dim) -> tuple or int
Returns the stride of :attr:`self` tensor.
Stride is the jump necessary to go from one element to the next one in the
specified dimension :attr:`dim`. A tuple of all strides is returned when no
argument is passed in. Otherwise, an integer value is returned as the stride in
the particular dimension :attr:`dim`.
Args:
dim (int, optional): the desired dimension in which stride is required
Example::
>>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
>>> x.stride()
(5, 1)
>>> x.stride(0)
5
>>> x.stride(-1)
1
"""
  • 对于连续内存的矩阵 shape [n_rows, n_cols]stride(0) 通常等于 n_cols
  • 但如果矩阵是转置视图、permute 或切片后的结果,stride 可能不等于列数。此时内存不连续,两行之间会有间隔。

普通 Safe Softmax 每次都要遍历全局内存三次:读 max、读并累加 sum、读并写归一化结果。这三个遍历会严重消耗内存带宽。

Triton这个版本的思路是:

  • 如果一整行数据能放入 寄存器/共享内存(即 n_cols 不大),我们就可以只读一次、只写一次
  • 在片上完成所有计算,内存访问被压缩到理论极限(2 次:1 读 + 1 写)。
  • 通过 tl.maxtl.sum 这些高效的归约原语(内部由 warp shuffle 实现),不必把中间结果写回全局内存。

但这也带来了限制

  • BLOCK_SIZE 必须容纳整行。当 n_cols 很大(例如 100k 长序列)时,硬件没有足够的寄存器/共享内存给每个线程,kernel 会直接无法启动。这就解释了为什么需要后面的 tileonline 版本。

2. Tile版本实现#

Fuse 版本的核心理念是 整行一次性加载,在寄存器/共享内存中完成所有运算,然后一次写回。

它做到了最低的全局内存访问:每行 1 读 + 1 写

但是它的致命问题是片上容量。 假设某一行非常大,例如1_000_000_000这么大的一行,那么Kernel只能去在HBM上进行交互,会让整个效率骤降。

因此我们需要设计出一种能够处理一行数据非常大的方案。

一个自然的想法就是分块处理,这就是我们提出的Tile版本。

Tile版本就是按照把原本的数据分块进行读取:

  1. 第一遍:遍历所有块,计算全局最大值 mm
  2. 第二遍:用已知的 mm,再次遍历所有块,累加得到全局归一化分母 dd
  3. 第三遍:用 m,dm, d,再次遍历所有块,计算最终输出并写回。

代码实现:

特别鸣谢 星合の空
@triton.jit
def kernel_softmax_tile(
x_ptr, x_row_stride,
y_ptr, y_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
CACHE_OPT: tl.constexpr,
):
row_idx = tl.program_id(0)
x_ptr += row_idx * x_row_stride
y_ptr += row_idx * y_row_stride
mm = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
mm = tl.maximum(mm, x)
mm = tl.max(mm)
ss = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
if CACHE_OPT:
for i in range(tl.cdiv(n_cols, BLOCK_SIZE) - 1, -1, -1):
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
x = tl.exp(x - mm)
ss += x
else:
for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
x = tl.exp(x - mm)
ss += x
ss = tl.sum(ss)
eps = float(1e-9)
ss = tl.maximum(ss, eps)
for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
x = tl.exp(x - mm) / ss
tl.store(y_ptr + idx, x, mask=idx < n_cols)
def triton_softmax_dim1_tile(x, cache_opt=True):
n_rows, n_cols = x.shape
y = torch.empty_like(x)
kernel_softmax_tile[[n_rows]](
x, x.stride(0),
y, y.stride(0),
n_cols,
BLOCK_SIZE=2**14,
CACHE_OPT=cache_opt,
num_warps=32,
)
return y

2.1. 分块找最大值#

我们维护一个长度为BLOCK_SIZE的向量,并且初始化为-inf,用来保存当前所有线程各自观察到的最大值。之后不断迭代这个mm。 具体示意图如下:

Tile找最大值示意图
Tile找最大值示意图

为什么这里先用向量,后规约呢?

为了避免每次加载都要跨线程同步,只有在最后循环结束后做一次最终规约。

之后我们需要同样计算全局和,不过根据代码可以发现,我们此时选择倒序。 正序还是倒序无关数学结果,但是倒序可以提升缓存命中率。 循环结束后,和max一样,也是最后才进行规约。

为什么倒序会更快?
  • 第一遍扫描结束的那一刻,最后几个块的输入数据很可能还在GPU的 L2 缓存里。
  • 第二遍若正序开始,会先加载首部块,此时尾部块可能已被逐出,再次加载发生缓存缺失。
  • 倒序则直接从尾部开始,这些数据大概率命中缓存,减少 HBM 访问,提升带宽利用率。
GPU编程经典技巧:顺序写,倒序读,来捕获缓存中的驻留数据。

Sum和Norm的示意图
Sum和Norm的示意图

最终遍历所有块,将结果写入输出张量对应的行。

实际上经过测试,在小规模数据中,Tile版本速度明显效率Fused速度,但是在大规模数据中,性能有明显提升。

Tile速度慢的主要原因是因为访存次数的增多,会导致理论速度上限下降,但是这也是没办法的,毕竟Fused无法处理长行。 因此混合场景往往是: 对于小规模数据使用Fuse,对于大规模数据使用Tile。

实际上后边还有更加先进的版本

性能对比测试我放在最后。

3. Online版本实现#

Tile相比较Fused来说,解决了行规模大计算缓慢的问题,但是又增加了访存次数,那有没有更好的优化方式呢?

Safe Softmax必须先找到全局最大值,才能够计算所有元素的指数和。

但是实际上我们完全可以采用滑动窗口的思想,在遍历中维护一个运行时最大值,并且动态将已经计算的归一化项调整到新的最大值基准,从而将“找最大值”和“算指数和”合并为单次遍历。

Online normalization算法伪代码
Online normalization算法伪代码

3.1. Online计算过程#

符号含义
xjx_j输入向量的第jj个元素
mj1m_{j-1}遍历到第jj个元素之前,前j1j-1个元素的运行时最大值
mjm_j遍历到第jj个元素之后,前jj个元素的运行时最大值
dj1d_{j-1}遍历到第jj个元素之前,前j1j-1个元素mj1m_{j-1}为基准的指数和(动态归一化项)
djd_j遍历到第jj个元素之后,前jj个元素mjm_j为基准的指数和
VV输入向量的总长度
mVm_V遍历完成后,整个输入向量的全局最大值
dVd_V遍历完成后,整个输入向量以全局最大值为基准的总指数和(Softmax的归一化分母)

遍历输入向量的每个元素xjx_j时,同时更新运行时最大值和动态归一化项,仅需单次遍历即可得到最终的mVm_VdVd_V

mj=max(mj1, xj)dj=dj1emj1mj+exjmj\boxed{ \begin{align} m_j &= \max\left(m_{j-1},\ x_j\right) \tag{1} \\ d_j &= d_{j-1} \cdot e^{m_{j-1} - m_j} + e^{x_j - m_j} \tag{2} \end{align} }

作用:记录遍历过程中遇到的最大元素,保证后续指数计算不会溢出。

  • 逻辑极其简单:每次取「之前的最大值」和「当前元素」的较大者
  • 数值稳定性:mjm_j始终是单调非递减的,永远不会溢出或下溢

公式可以拆分为两部分理解:

dj=dj1emj1mj旧基准和的调整项+exjmj当前元素的指数项d_j = \underbrace{d_{j-1} \cdot e^{m_{j-1} - m_j}}_{\text{旧基准和的调整项}} + \underbrace{e^{x_j - m_j}}_{\text{当前元素的指数项}}

根据当前元素xjx_j是否超过之前的最大值,公式会自动简化为两种逻辑:

情况1:xjmj1x_j \leq m_{j-1}(最大值不变)
  • 代入公式(1):mj=mj1m_j = m_{j-1}
  • 调整因子:emj1mj=e0=1e^{m_{j-1} - m_j} = e^0 = 1
  • 公式(2)简化为:
dj=dj1+exjmj1 d_j = d_{j-1} + e^{x_j - m_{j-1}}
  • 最大值没有变化,不需要调整之前的总和,直接把当前元素的指数值累加到归一化项中即可。
情况2:xj>mj1x_j > m_{j-1}(发现更大值)
  • 代入公式(1):mj=xjm_j = x_j
  • 调整因子:emj1xj<1e^{m_{j-1} - x_j} < 1(因为xj>mj1x_j > m_{j-1},指数为负)
  • 当前元素的指数项:exjmj=exjxj=1e^{x_j - m_j} = e^{x_j - x_j} = 1
  • 公式(2)简化为:
dj=dj1emj1xj+1 d_j = d_{j-1} \cdot e^{m_{j-1} - x_j} + 1
  • 发现了更大的元素,需要把之前所有元素的指数和按比例缩小,再加上当前元素的指数值1。
def online_softmax(x: np.ndarray) -> np.ndarray:
m = -np.inf # 运行时最大值
d = 0.0 # 动态归一化项
# 第1次遍历:同时计算全局最大值m和归一化和d(只读输入)
for xi in x:
new_m = max(m, xi)
# 核心:动态调整d到新的最大值基准
d = d * np.exp(m - new_m) + np.exp(xi - new_m)
m = new_m
# 第2次遍历:计算最终概率(读输入+写输出y)
y = np.exp(x - m) / d
return y

3.2. Triton版本实现#

@triton.jit
def kernel_softmax_online(
x_ptr, x_row_stride,
y_ptr, y_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
CACHE_OPT: tl.constexpr,
):
row_idx = tl.program_id(0)
x_ptr += row_idx * x_row_stride
y_ptr += row_idx * y_row_stride
mm = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
ss = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
mm_new = tl.maximum(mm, x)
if i: # 第 1 轮不需要,且容易整出 nan
ss *= tl.exp(mm - mm_new)
x = tl.exp(x - mm_new)
ss += tl.where(idx < n_cols, x, 0.0)
mm = mm_new
mm_new = tl.max(mm)
ss *= tl.exp(mm - mm_new)
ss = tl.sum(ss)
mm = mm_new
eps = float(1e-9)
ss = tl.maximum(ss, eps)
if CACHE_OPT:
for i in range(tl.cdiv(n_cols, BLOCK_SIZE) - 1, -1, -1):
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
x = tl.exp(x - mm) / ss
tl.store(y_ptr + idx, x, mask=idx < n_cols)
else:
for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
x = tl.exp(x - mm) / ss
tl.store(y_ptr + idx, x, mask=idx < n_cols)
def triton_softmax_dim1_online(x, cache_opt=True):
n_rows, n_cols = x.shape
y = torch.empty_like(x)
kernel_softmax_online[[n_rows]](
x, x.stride(0),
y, y.stride(0),
n_cols,
BLOCK_SIZE=2**12,
CACHE_OPT=cache_opt,
num_warps=32,
)
return y

我们来具体看一下是如何使用Triton来实现的。 这里每一个Program处理输入张量中的一行,这也是最自然的并行方式,因为Softmax是按照行来独立计算的。

我们首先看内核参数列表:

def kernel_softmax_online(
x_ptr, x_row_stride, # 输入张量的指针和行步长
y_ptr, y_row_stride, # 输出张量的指针和行步长
n_cols, # 每行的元素个数
BLOCK_SIZE: tl.constexpr, # 每个块处理的元素个数
CACHE_OPT: tl.constexpr, # 是否启用缓存优化
):

这里我们并没有直接传入输入张量的形状,而是传递了指针和行步长。 这是因为输入张量可能并不是连续的(切片或转置)。传入步长可以正确计算每行的起始地址。

之后就正常计算输入行和输出行的起始地址。

再然后就是关于局部变量

mm = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
ss = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
为什么这里是向量而不是标量呢?

在Triton中,当你定义一个形状为BLOCK的张量时,编译器会将这个张量的元素分配给Program中的所有线程。每个线程负责处理一个或者多个元素。假设现在BLOCK_SIZE=4096,num_warps=32,那么当前Program总共有1024个线程,每个线程处理4个元素。

我们把Online Softmax的算法和Triton的编程模型结合起来,看看for循环的主要作用。

为方便理解,我们假设要处理的这一行数据长度 n_cols = 16,我们设定的 BLOCK_SIZE = 8,这样循环就需要迭代 triton.cdiv(16, 8) = 2 次。

初始化#

# 1. 计算当前 Program 要处理的那一行数据的起始内存地址。
row_idx = tl.program_id(0)
x_ptr += row_idx * x_row_stride
y_ptr += row_idx * y_row_stride
# 2. 初始化运行时统计量。
# mm: 一个长度为 BLOCK_SIZE 的张量,每个元素都被初始化为负无穷。
# ss: 一个长度为 BLOCK_SIZE 的张量,每个元素都被初始化为 0.0。
# 为什么是张量?因为此时我们还没拿到全局最大值,需要用一个和块大小一样的“容器”
# 来维护块内每个元素的运行时最大值和指数和。
mm = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
ss = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

第1次迭代#

for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
# 第一轮 (i=0):
# 第1步:加载数据
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
# idx 是一个包含 [0, 1, 2, 3, 4, 5, 6, 7] 的张量。
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
# x 也是一个 [8] 的张量,它的 8 个元素由 Program 内的线程共同持有。假设加载的值是 [3, 1, 4, 1, 5, 9, 2, 6]。
# 第2步:更新运行时最大值
mm_new = tl.maximum(mm, x)
# 对于 x 的每个元素,比较 mm 中对应位置的元素(都是 -inf)和 x 的值,取较大者。
# 结果 mm_new 等于 x: [3, 1, 4, 1, 5, 9, 2, 6]。
# 注意:这时 mm_new 仍是“块内局部最大值”,还不是全局最大值。
# 第3步:更新指数和
if i:
# i=0 时跳过,因为 mm 都是 -inf,调整无意义。
pass
# 第4步:计算当前块的指数并累加
x = tl.exp(x - mm_new)
# 每个线程独立计算它所负责元素的 exp(x_i - max_of_this_block) 值。
ss += tl.where(idx < n_cols, x, 0.0)
# 将计算出的指数值加到 ss 的对应位置上。此时 ss 保存的是第一个块的指数和。
# 第5步:更新旧统计量
mm = mm_new

第一轮结束后,mmss 保存了第一个数据块的统计信息,但它们还不是全局的。

第2次迭代#

# 第二轮 (i=1):
# 第1步:加载数据
idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
# idx 是 [8, 9, 10, 11, 12, 13, 14, 15]。
x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf"))
# 假设加载的值是 [2, 7, 1, 8, 2, 8, 1, 8]。
# 第2步:更新运行时最大值
mm_new = tl.maximum(mm, x)
# 比较上一轮的 mm [3, 1, 4, 1, 5, 9, 2, 6] 和当前的 x [2, 7, 1, 8, 2, 8, 1, 8]。
# 结果 mm_new = [3, 7, 4, 8, 5, 9, 2, 8]。
# 注意:这里mm_new[j] 维护的是前两块中,与 j 相关的数据的最大值。
# 这种设计是为了在接下来的计算中,为每个元素提供一个局部的偏移基准。
# 第3步:动态调整上一轮的指数和
if i: # i=1 时条件为真
ss *= tl.exp(mm - mm_new)
# 因为 mm_new 可能比 mm 大,我们需要把基于较小基准 mm 计算的指数和 ss,
# 通过乘以 exp(mm - mm_new) 这个小于1的因子,调整到以新的、更大的 mm_new 为基准。
# 这一步保证了数值稳定性,并确保 ss 在数学上的正确性。
# 第4步:计算当前块的指数并累加
x = tl.exp(x - mm_new)
ss += tl.where(idx < n_cols, x, 0.0)
# 将第二个块的指数加到刚被调整过的 ss 上。
# 此时 ss 保存的就是前两个块的指数和。
# 第5步:更新统计量
mm = mm_new

最终归约#

For循环结束后,mmss 还保留着长度为 BLOCK_SIZE 的张量形态,我们需要通过归约作得到整行的全局标量最大值和总和。

# 第1步:在 BLOCK_SIZE 内归约 mm,得到全局最大值
mm_new = tl.max(mm) # 在长度为 BLOCK_SIZE 的 mm 张量上执行 max 归约,得到标量 m_global。
# 第2步:再次调整 ss,使其基准对齐到全局最大值 m_global
ss *= tl.exp(mm - mm_new)
# 第3步:归约 ss,得到全局指数和
ss = tl.sum(ss) # 在长度为 BLOCK_SIZE 的 ss 张量上执行 sum 归约,得到标量 d_global。
mm = mm_new # mm 也变成了标量 m_global。

从这个例子中,实际上我们可以提炼出几个编写类似 Triton 内核的通用模式:

  1. 分块加载与循环:这是处理任意大小数据的基本模式。核心结构如下:
for block_start in range(0, N, BLOCK_SIZE):
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
block = tl.load(ptr + offsets, mask=mask)
# ... 在片上 (SRAM) 处理 block ...

这个模式允许我们以固定的、能装进片上内存的“块”为单位来处理超大数据集。

  1. 在线统计算法:当你需要多次遍历计算多个关联的统计量(如最大值和总和)时,可以思考能否将其合并到一次遍历中。模式是:在循环中更新一个统计量时,同步修正另一个关联的统计量。除了Softmax,这个方法在计算归一化的均值和方差时也很有用。

CUDA版本实现#

关于CUDA版本的实现,Oneflow给了很好的解释,大家可以移步Oneflow博客进一步学习

支持与分享

如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!

赞助
LeetGPU习题05:Softmax优化详解
https://dlog.com.cn/posts/leetgpu05/softmax/
作者
杜子源
发布于
2026-05-11
许可协议
CC BY-NC-SA 4.0
最后更新于 2026-05-11,距今已过 46 天

部分内容可能已过时

Profile Image of the Author
杜子源
都是风景,幸会
公告
请狠狠地打赏我,打赏一次,爆更一篇!!
音乐
封面

音乐

暂未播放

0:00 0:00
暂无歌词
分类
标签
站点统计
文章
29
分类
8
标签
11
总字数
81,272
运行时长
0
最后活动
0 天前

目录