LeetGPU习题05:Softmax优化详解
什么是Softmax?
1. 数学定义

这里 是自然对数的底数(约等于 2.718)。这个公式保证了:
- (每个值都在 0 到 1 之间)
- (所有值加起来等于 1)
这两个性质正好满足概率分布的数学要求,因此 softmax 的输出可以直接解释为多项分布的概率。
和Softmax对应的是argmax,argmax就是我们通常理解的max。
import torcha = torch.randn(10)b = torch.max(a)print(a)print(b.item())但是argmax在求导方面存在诸多问题,例如不可导等等。
2. 历史渊源
在现在的Transformer架构中,softmax无处不在,但是它最早是哪里来的呢?
1868年,玻尔兹曼在其奠基性统计力学有关玻尔兹曼分布的论文中提出了Softmax函数。 1902年,Gibbs在《统计力学基础原理》中对Softmax进行了形式化与推广。
实际上,玻尔兹曼分布定义了位于不同能量级的粒子数量的概率分布,该分布期望:
在能量与粒子数不变的前提下,能量级状态数量最大化。
在之后的逻辑回归中,研究者需要对具有多个无序类别的结果进行建模,因此它们将线性预测器的输出自然地对数-几率比参数化,由此推导出的概率公式就是Softmax。
在1989年,John S.bridle在论文中首次把这种网络输出指数化再归一化的层命名为Softmax,以区别于选择最大值的argmax。
3. 为什么是指数函数?
使用指数函数主要是两个原因:
- 与交叉熵结合。 当我们用 Softmax 的输出作为预测概率,然后搭配 交叉熵损失 函数时,损失函数对某个未归一化输出 的梯度会变得极其简洁:
(其中 是真实标签的 one-hot 编码)这意味着反向传播的梯度信号就是“预测值”与“真实值”的误差,这种线性减法形式使得训练极为稳定,避免了梯度消失或爆炸。
- 广义线性模型的规范链接。 在指数分布族的框架下,所有的概率分布(正态分布、泊松分布)都可以归纳为指数分布簇。 如果我们面对一个多分类问题,数学推导证明链接原始线性输出与最终概率分布的最优且唯一的规范链接函数,恰好就是Softmax。 这意味着使用 Softmax 和交叉熵时,我们实际上是在正确指定的统计模型下进行最大似然估计。
4. 继续深挖Softmax
在神经网络分类层中,最后的全连接层输出通常被称为 Logits(未归一化的对数概率)。Logits 的取值范围是整个实数域 ,而概率必须位于 [0, 1] 且和为 1。Softmax 在数学上实现了一个光滑的、满射的映射:
它将无限延伸的实数空间“压缩”成一个 (K-1) 维的概率单纯形。
注意 Softmax 的分母 被称为 Log-Sum-Exp 函数,它是 max 函数的一个光滑近似:
而 Softmax 恰好就是 Log-Sum-Exp 函数对每个分量 的偏导数。这揭示了 Softmax 的另一个数学身份:它是 算子的光滑梯度。
当存在一个输入 显著大于其他元素时, 会在分母中占绝对主导地位,对应的 会无限趋近于 1,而其他 趋近于 0。Softmax 就退化成一个近乎 one-hot 的分布,表现得像 argmax。反之,当输入值差异很小时,它会给出一个接近均分的概率分布。通过引入 温度参数 :
我们可以控制输出的“软硬程度”:
- :趋近于 argmax,分布极尖锐。
- :趋近于均匀分布,极为平滑。
5. 为什么Transformer中要用Softmax
在 Transformer 架构中,Softmax更是自注意力机制的核心组件。
很多市面上的八股都会告诉大家Softmax的若干作用,它的公式如下:
注意力机制本质上是信息检索的连续化。 设想一个包含键值对 的数据库,给定查询 后,最直接的做法是找到相似度最高的键,并输出其对应的值。
这等价于对相似度分数执行 argmax 后做点积。但是argmax 分段常数并且不可导,阻碍了基于梯度的端到端训练。
Softmax 正是 argmax 的一种光滑化、处处可导的近似。 它以连续概率的形式分配注意力。 相似度越高的键,获得指数级更大的权重,但所有键都会保留非零梯度通路。 这种“软查找”既保留了选择性聚合的能力,又让误差信号能无阻碍地反向传播到查询和键的表示学习中。
原始的注意力分数矩阵 中,元素可以是任意实数,数值范围不受约束。 如果直接用这些裸分去加权 矩阵,随着序列长度增加或网络加深,输出向量的范数会不受控制地增长或衰减,导致训练不稳定。
Softmax 通过分母 将所有分数映射到 (0,1) 之间,且每行的和为 1。于是注意力输出退化为 中各行向量的一种凸组合,即加权平均值。
由于凸组合的范数严格被输入范数的最大值所限制,无论序列有多长,注意力层的输出量级始终保持在稳定区间。这是在深层架构中堆叠数十层注意力而无需额外范数约束的数学基础。
既然只需要加权平均,那更简单的 归一化(L1范数)为何不被采用?
原因在于长序列场景下的信噪分离需求。 线性归一化会将注意力分散到大量低相关度的 token 上,每个 token 都分得微小但不可忽略的权重。这种“平均主义”会稀释核心信息的表征,对推理和长程依赖非常不利。
指数函数 作为凸的、严格正且增长迅速的激活函数,天然具备对比度放大的能力。 当分数间存在微小差异时,指数运算会显著拉大权重差距. 高相似度的 token 获得压倒性权重,低相似的 token 权重急剧趋近于零。 效果上,Softmax 自动实现了自适应稀疏性。 模型无需预先指定哪几个 token 应该被“看见”,而是通过指数竞争机制动态地让少数关键元素主导输出,同时屏蔽其余噪声。
为什么要引入缩放因子
我们来看这个例子:
想象一个句子有三个token:[I love you]。
当前哟啊计算出第一个token的注意力,模型算出了它与三个key的点积分数:[2, 5, 10]。
如果我们直接送进Softmax,输出则是[0.0000015%,0.00003%,99.99997%]
这几乎就是一个one-hot向量。第三个token完全拿走了所有的注意力,前两个token的梯度则几乎为0。 这会导致输出对输入的任何微小变化都毫无反应。
那我们现在加入缩放因子,假设,则缩放后的分数变为:[0.25, 0.625, 2.5]
这个时候输出就变成了:[7.5%, 10.9%, 81.6%]
虽然仍然是第三个值最大,但是剩下两个仍然保留了可观的份额,输出不再是one-hot。
为什么恰好是这个缩放因子?
实际上,对于维度的随机向量,假设每个分量独立且方差为1,点积的方差恰好等于。因此点积标准差恰好是这个,也就是将点积的方差重新固定为 1。
- 缩放让分数分布恢复成“常温”状态,恰好落在 Softmax 既能拉开差距、又不至饱和的区间。
- 如果除以更小的数(如 ),分布仍偏宽,模型会重新倾向 one-hot;
- 如果除以更大的数(如 ),分布收得过窄,注意力会变得过于均匀,变成“什么都注意,什么都记不住”。
Pytorch版本实现
1. naive版本
好了,我们按照数学公式,直接来写吧:
def naive_softmax(x, dim=-1): exp_x = torch.exp(x) return exp_x / exp_x.sum(dim=dim, keepdim=True)OK, 带入[1, 100, 100000]试一下,你试试呢~
(mycuda) 012 python psoftmax.pytensor([0., nan, nan])(mycuda) 012这段代码看起来没什么问题,但是只要 x 中出现较大的元素exp就会直接溢出为 inf,导致整个 Softmax 输出变成一堆 nan。
2. Safe Softmax
为了解决这个问题,论文中提出了它的解决方案:

我来给大家翻译一下。
- 找到这个向量/矩阵中要求行的最大值M。
- 将所有含有的式子变换为
- 带入原式子
完整的代码如下,其中有注释,不再赘述:
import torch
def softmax(x, dim=-1): """ softmax(z)_i = exp(z_i) / sum_j exp(z_j) """ # 第 1 步:求最大值 # 沿目标维度 dim 求最大值,keepdim=True 保持原维度数, # 便于后续广播减法。x_max 的形状会在 dim 维度变成 1,其他维不变。 x_max = x.max(dim=dim, keepdim=True)[0] # [0] 取出最大值
# 第 2 步:平移输入数据,减去最大值 # 数学上等价于将分子分母同除以 exp(x_max),防止指数爆炸。 # 因为 x_max 是沿 dim 维的最大值,x - x_max 的所有元素 ≤ 0, safe_x = x - x_max # 广播机制
# 第 3 步:指数计算 # 对平移后的张量逐元素求 exp。 exp_x = torch.exp(safe_x)
# 第 4 步:求和 # 沿 dim 维度对 exp_x 求和,同样 keepdim=True 保持维度以便广播。 sum_exp = exp_x.sum(dim=dim, keepdim=True)
# 第 5 步:归一化得到最终概率 output = exp_x / sum_exp
return output
def naive_softmax(x, dim=-1): exp_x = torch.exp(x) return exp_x / exp_x.sum(dim=dim, keepdim=True)
if __name__ == "__main__": # 一维向量 x1 = torch.tensor([1.0, 2.0, 3.0]) p1 = softmax(x1) print("一维 softmax:\n",p1) print(f"Torch Softmax:\n{torch.softmax(x1, dim=-1)}") print("求和 =", p1.sum().item()) # 应为 1
# 二维矩阵 x2 = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) p2 = softmax(x2, dim=-1) # 对每一行做 softmax print("\n二维 softmax:\n",p2) print(f"Torch Softmax:\n{torch.softmax(x2, dim=-1)}") print("每行求和 =", p2.sum(dim=-1)) # 结果 [1, 1]Triton版本实现
1. Safe Softmax版本实现
我们已经学习过了element wise和reduce,大家可以仔细分析一下,softmax的运算都分为哪些呢?

从论文中,我们将其进行分解为Element wise和reduce试一下?
回顾数值稳定的 Safe Softmax 公式:
它由以下 5个子运算 组成,可以按模式归为两类:
| 步骤 | 运算 | 类型 | 说明 |
|---|---|---|---|
| ① 求最大值 | reduction | 跨元素归约 | 需要在整行/整向量上进行比较归约,产生一个标量。 |
| ② 平移 | element-wise | 逐元素 | 每个元素独立减去同一个标量,完全并行。 |
| ③ 指数 | element-wise | 逐元素 | 每个元素独立计算指数函数。 |
| ④ 求和 | reduction | 跨元素归约 | 将指数结果累加为一个标量。 |
| ⑤ 归一化 | element-wise | 逐元素 | 每个指数除以上述标量,得到最终概率。 |
论文指出,几乎所有深度学习框架使用的 Safe Softmax 都执行 三次遍历:
- 第一遍求最大值 ;
- 第二遍求指数和 ;
- 第三遍计算输出 。
我们用triton来实现一下
@triton.jitdef kernel_softmax_fuse( x_ptr, x_row_stride, y_ptr, y_row_stride, n_cols, BLOCK_SIZE: tl.constexpr,): row_idx = tl.program_id(0) x_ptr += row_idx * x_row_stride y_ptr += row_idx * y_row_stride idx = tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf")) x = tl.exp(x - tl.max(x)) eps = float(1e-9) x /= tl.maximum(tl.sum(x), eps) tl.store(y_ptr + idx, x, mask=idx < n_cols)
def triton_softmax_dim1_fuse(x): n_rows, n_cols = x.shape y = torch.empty_like(x) kernel_softmax_fuse[[n_rows]]( x, x.stride(0), y, y.stride(0), n_cols, BLOCK_SIZE=triton.next_power_of_2(n_cols), num_warps=32, ) return y和原版不同的是,它只做 一次加载,然后直接在片上完成:
tl.max(x)→ 求最大值x - tl.max(x)→ 平移tl.exp(...)→ 指数tl.sum(x)→ 求和(归约)tl.maximum(..., eps)→ 防除零- 除法和写回
对于参数来说,
@triton.jitdef kernel_softmax_fuse( x_ptr, # 输入张量的基地址 x_row_stride, # 输入行步幅 y_ptr, # 输出张量的基地址 y_row_stride, # 输出行步幅 n_cols, # 列数(每行的元素个数) BLOCK_SIZE: tl.constexpr, # 编译时常量,决定一次加载的元素数):Triton 是显式内存编程模型,kernel 接收到的是原始指针(即张量首元素的地址),而不是 tensor 对象。我们需要手动计算偏移来访问特定行的数据。
stride 是连续两行起点之间的元素个数(或者说偏移量)。在 PyTorch 中,x.stride(0) 就给出了这个值。(我们点击stride就可以查看函数定义)
def stride(self, dim: None = None) -> tuple[_int, ...]: r""" stride(dim) -> tuple or int
Returns the stride of :attr:`self` tensor.
Stride is the jump necessary to go from one element to the next one in the specified dimension :attr:`dim`. A tuple of all strides is returned when no argument is passed in. Otherwise, an integer value is returned as the stride in the particular dimension :attr:`dim`.
Args: dim (int, optional): the desired dimension in which stride is required
Example::
>>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) >>> x.stride() (5, 1) >>> x.stride(0) 5 >>> x.stride(-1) 1 """- 对于连续内存的矩阵 shape
[n_rows, n_cols],stride(0)通常等于n_cols。 - 但如果矩阵是转置视图、
permute或切片后的结果,stride可能不等于列数。此时内存不连续,两行之间会有间隔。
普通 Safe Softmax 每次都要遍历全局内存三次:读 max、读并累加 sum、读并写归一化结果。这三个遍历会严重消耗内存带宽。
Triton这个版本的思路是:
- 如果一整行数据能放入 寄存器/共享内存(即
n_cols不大),我们就可以只读一次、只写一次。 - 在片上完成所有计算,内存访问被压缩到理论极限(2 次:1 读 + 1 写)。
- 通过
tl.max和tl.sum这些高效的归约原语(内部由 warp shuffle 实现),不必把中间结果写回全局内存。
但这也带来了限制
BLOCK_SIZE必须容纳整行。当n_cols很大(例如 100k 长序列)时,硬件没有足够的寄存器/共享内存给每个线程,kernel 会直接无法启动。这就解释了为什么需要后面的tile和online版本。
2. Tile版本实现
Fuse 版本的核心理念是 整行一次性加载,在寄存器/共享内存中完成所有运算,然后一次写回。
它做到了最低的全局内存访问:每行 1 读 + 1 写。
但是它的致命问题是片上容量。 假设某一行非常大,例如1_000_000_000这么大的一行,那么Kernel只能去在HBM上进行交互,会让整个效率骤降。
因此我们需要设计出一种能够处理一行数据非常大的方案。
一个自然的想法就是分块处理,这就是我们提出的Tile版本。
Tile版本就是按照把原本的数据分块进行读取:
- 第一遍:遍历所有块,计算全局最大值 。
- 第二遍:用已知的 ,再次遍历所有块,累加得到全局归一化分母 。
- 第三遍:用 ,再次遍历所有块,计算最终输出并写回。
代码实现:
def triton_softmax_dim1_fuse(x): n_rows, n_cols = x.shape y = torch.empty_like(x) kernel_softmax_fuse[[n_rows]]( x, x.stride(0), y, y.stride(0), n_cols, BLOCK_SIZE=triton.next_power_of_2(n_cols), num_warps=32, ) return y
@triton.jitdef kernel_softmax_tile( x_ptr, x_row_stride, y_ptr, y_row_stride, n_cols, BLOCK_SIZE: tl.constexpr, CACHE_OPT: tl.constexpr,): row_idx = tl.program_id(0) x_ptr += row_idx * x_row_stride y_ptr += row_idx * y_row_stride
mm = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)): idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf")) mm = tl.maximum(mm, x) mm = tl.max(mm)
ss = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
if CACHE_OPT: for i in range(tl.cdiv(n_cols, BLOCK_SIZE) - 1, -1, -1): idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf")) x = tl.exp(x - mm) ss += x else: for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)): idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf")) x = tl.exp(x - mm) ss += x
ss = tl.sum(ss) eps = float(1e-9) ss = tl.maximum(ss, eps)
for i in range(0, tl.cdiv(n_cols, BLOCK_SIZE)): idx = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE x = tl.load(x_ptr + idx, mask=idx < n_cols, other=-float("inf")) x = tl.exp(x - mm) / ss tl.store(y_ptr + idx, x, mask=idx < n_cols)
def triton_softmax_dim1_tile(x, cache_opt=True): n_rows, n_cols = x.shape y = torch.empty_like(x) kernel_softmax_tile[[n_rows]]( x, x.stride(0), y, y.stride(0), n_cols, BLOCK_SIZE=2**14, CACHE_OPT=cache_opt, num_warps=32, ) return y2.1. 分块找最大值
我们维护一个长度为BLOCK_SIZE的向量,并且初始化为-inf,用来保存当前所有线程各自观察到的最大值。之后不断迭代这个mm。
具体示意图如下:

为了避免每次加载都要跨线程同步,只有在最后循环结束后做一次最终规约。
之后我们需要同样计算全局和,不过根据代码可以防线,我们此时选择倒序。 正序还是倒序无关数学结果,但是倒序可以提升缓存命中率。 循环结束后,和max一样,也是最后才进行规约。
- 第一遍扫描结束的那一刻,最后几个块的输入数据很可能还在GPU的 L2 缓存里。
- 第二遍若正序开始,会先加载首部块,此时尾部块可能已被逐出,再次加载发生缓存缺失。
- 倒序则直接从尾部开始,这些数据大概率命中缓存,减少 HBM 访问,提升带宽利用率。

最终遍历所有块,将结果写入输出张量对应的行。
实际上经过测试,在小规模数据中,Tile版本速度明显效率Fused速度,但是在大规模数据中,性能有明显提升。
Tile速度慢的主要原因是因为访存次数的增多,会导致理论速度上限下降,但是这也是没办法的,毕竟Fused无法处理长行。 因此混合场景往往是: 对于小规模数据使用Fuse,对于大规模数据使用Tile。
性能对比测试我放在最后。
3. Online版本实现
Tile相比较Fused来说,解决了行规模大计算缓慢的问题,但是又增加了访存次数,那有没有更好的优化方式呢?
开头我放了一篇18年的文章,大家可以阅读一下,里边详细讲到了优化方案,那么接下来我也稍微讲解一番。
支持与分享
如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!