CUDA学习之路[9]:粗看规约Reduce算法
前置知识
1. 线程层次
详细见之前的文章:GPU 并行本质
Grid(网格)├── Block 0(线程块)│ ├── Warp 0(32个线程)│ │ ├── Thread 0│ │ ├── Thread 1│ │ └── ...│ │ └── Thread 31│ ├── Warp 1│ └── ...├── Block 1└── ...快速回顾:
- Thread:执行的最小单元。每个线程有自己的寄存器。
- Warp:32 个线程为一组,是 GPU 的调度和执行的最小单位。一个 warp 内的 32 个线程共享一个程序计数器,所以它们永远执行同一条指令(SIMT 模型)。
- Block:一组 warp 组成一个 block。一个 block 内的所有线程共享一块共享内存。一个 block 最多 1024 个线程(= 32 个 warp)。
- Grid:整个 kernel 启动时创建的所有 block 组成 grid。所有 block 共享全局内存。
这是 NVIDIA GPU 的硬件设计决定的。从 G80 到 H100,warp size 始终是 32。当你写 if (threadIdx.x < 16) 这样的代码时,同一个 warp 的前 16 个线程走 if 分支,后 16 个线程闲着——这就是 warp divergence,会让性能减半。
2. 内存层次
| 内存类型 | 速度 | 容量 | 可见范围 |
|---|---|---|---|
| 全局内存 | ~1 TB/s(HBM) | 大(GB 级) | 所有线程 |
| 共享内存 | ~200 TB/s(SRAM) | 小 | 同一个 Block |
| 寄存器 | 最快(0 延迟) | 极少 | 单个线程 |
把数据从全局内存搬到共享内存和寄存器,在最快的层级上完成计算,再把结果写回去。
3. SIMT 和 Warp Divergence
GPU 使用的是 SIMT(Single Instruction, Multiple Thread),对应的一个 warp 的 32 个线程在同一时刻执行同一条指令。
当你写了分支代码(if/else),如果同一个 warp 内有些线程走 if,有些走 else,硬件只能先执行 if 分支(让 else 的线程干等),再执行 else 分支(让 if 的线程干等)。
这就是 warp divergence,代码实际上串行化了。
这就是为什么规约算法中要小心处理 if (threadIdx.x < stride) 这类条件,当 stride 小于 32 时,分支就开始在同一 warp 内产生 divergence。
4. 共享内存与访存冲突
共享内存被划分为若干个Bank,通常为32个。每个Bank的每个时钟周期只能服务于一个线程。 如果同一个Warp中的多个线程访问同一个Bank中的不同地址,就会发生Bank Conflict,导致访问串行化。
我们来看一下这三种情况:
- [无冲突]:所有线程访问不同bank,或访问同一bank的同一地址(广播)。
- [2路冲突]:两个线程访问同一bank的不同地址,访存时间翻倍。
- [最坏情况]:32个线程全挤在同一个bank上,串行化32倍。
归约算法中内存访问模式的微小差异,就可能对性能产生巨大的影响。
规约是什么?
通俗地说,规约就是”把一堆东西合并成一个东西。”。
比如你手里有 100 个数,求它们的总和,这就是一次规约。
形式化地讲,给定集合 和一个二元操作 ,规约计算的是:
我们可以发现,归约的特点是:

| 特性 | 逐元素操作 | 规约操作 |
|---|---|---|
| 输出形状 | 与输入相同 | 维度降低 |
| 数据依赖 | 无 | 前后依赖 |
| 并行策略 | 直接映射 | 需要多轮合并 |
| 典型例子 | add, sigmoid, gelu | sum, max, mean, softmax |
常见的规约算子一览:
| 算子 | 操作 | 可能出现的位置 |
|---|---|---|
| Sum | 损失函数求和、BatchNorm 均值 | |
| Max / Min | MaxPooling、Softmax 的稳定化 | |
| Argmax | 比较值,保留索引 | 分类模型最后的预测 |
| Mean | sum 再除以 N | LayerNorm / BatchNorm |
| Variance | 先 sum 再求方差 | LayerNorm / BatchNorm |
| Dot Product | 逐元素乘再 sum | 矩阵乘法的核心 |
其中 Sum 是所有规约算子的基础,它满足结合律和交换律,单位元是 0,硬件实现高效,而且几乎所有深度学习框架都有它。
本文的全部讨论都将围绕向量求和展开。
优化之旅

1. Baseline串行求和
def sequential_sum(arr): result = 0 for x in arr: result += x return result最简单的求和,基本上大家都可以写出来。 N 个元素,N-1 次加法,是线性链式结构,每一步都依赖前一步的结果。
那想要让这个算法跑的更快,首先得改变计算的拓扑结构。
2. 树形规约
既然串行是链式的,我们能不能把它变成树形?
树形规约的基本思想就是两两相加。
具体来说就是:

__global__ void reduce_vo(const float* in, float* out, int N) { __shared__ float sdata[1024];
int tid = threadIdx.x; int idx = blockIdx.x * blockDim.x + tid;
sdata[tid] = (idx < N) ? in[idx] : 0.0f; __syncthreads();
// 步长翻倍 for(int s = 1; s < blockDim.x; s *= 2) { if (tid % (2 * s) == 0) { sdata[tid] += sdata[tid + s]; } __syncthreads(); }
// 每个block的结果用原子加法写回全局内存 if (tid == 0) atomicAdd(out, sdata[0]);}我们来拆解一下上边这个代码。
2.1. shared
为什么我们要在这里用到共享内存呢?实际上我们当然可以每个线程从全局内存里读取数字,然后在寄存器加加减减,最后把结果写回全局内存即可。
但是共享内存会给同一个Block里的线程提供一个空间,它的物理位置在SM内部,延迟非常低,带宽非常高。 一个Block内的所有线程几乎都可以零成本的读写这块共享空间,折让线程的数据交换快了非常多。
并且我们声明了__shared__ float sdata[1024],这意味每个block都会有一份sdata,大小是1024,如果我的block启动1024个线程,那每个线程可以有一个专属的sdata[tid]的格子来存放从全局内存搬运过来的那个数。
int tid = threadIdx.x;int idx = blockIdx.x * blockDim.x + tid;sdata[tid] = (idx < N) ? in[idx] : 0.0f;__syncthreads();这一步表示每个线程都在计算自己的idx,然后把自己负责的那个元素从in数组里拿进来,然后放进共享内存对应的小格子里。
如果数组长度N的不是Block大小的整数倍,最后几个线程可能会越界,所以通过(idx < N) ? in[idx] : 0.0f来做边界保护。
最终通过__syncthreads()确保所有线程都把自己的数据放妥当,等到所有线程都就绪后才开始进行下一步。
2.2 循环部分是如何规约的?
现在共享内存里已经摆好每个线程带进来的数,接下来的任务是在这个block内部把这些数相加,最后浓缩成为一个和。
最基础的算法采用的是树状折半规约,这里采用循环来实现。
for(int s = 1; s < blockDim.x; s *= 2) { if (tid % (2 * s) == 0) { sdata[tid] += sdata[tid + s]; } __syncthreads();}第一次循环,s=1,条件tid % 2 == 0,这里把所有偶数编号的线程唤醒。
例如线程0把sdata[1]加到自己sdata[0]上,
线程2把sdata[3]加到sdata[2]上,以此类推。
等到所有干活的线程执行完,我们再进行一次同步,确保所有的假发都写进到了共享内存,其他线程才能够进入到下一轮。
第二次s=2,只有tid % 4 == 0的线程工作,步长翻倍,间隔拉大,每一轮参与计算的线程数减半,而被合并的数据块长度翻倍,如此往复。
直到s达到blockDim.x,循环结束,此时sdata[0]中装的就是整个block所有元素的和。
2.3. 最终输出
if (tid == 0) atomicAdd(out, sdata[0]);最终通过原子加法把这份局部和累加到全局输出out中。
因为有多个block在同时做这件事,所以需要使用atomicAdd来确保不会发生读写冲突。
2.4. 如何优化呢?
现在逻辑上,这个规约算法已经没有任何问题了,但是从性能上来看,实际上差的还是蛮远的。
那究竟差在哪里呢?
实际上,取模运算在GPU上是非常慢的,并且2*s在编译器中没有办法把它进行高效优化成为移位和加法,绝大部分都在排队等待除法结果,计算资源被浪费。
共享内存虽然快,但如果不注意,很容易发生访存冲突。 它被分成 32 个 bank,每个 bank 在每个时钟周期内只能服务一个线程的访问。如果同一个 warp 里的两个线程访问了同一个 bank 的不同地址,就会产生冲突,访问会被串行化。 回到我们的循环:第一轮 s=1,干活的是偶数线程 0,2,4,…,它们访问的地址是 (0,1), (2,3), (4,5),相邻干活线程之间隔了一个线程,但访问的 bank 还勉强算分散。 可随着 s 增大,干活线程越来越稀疏,到 s=16 时,参与运算的线程 0 和线程 32 访问的地址分别落在 bank 0 和 bank 32——注意 bank 数是 32,bank 32 在硬件上就是 bank 0 的另一行,于是 bank 0 被同时访问,冲突发生。 再往后,s=32 时干活线程的间隔更大,几乎所有访问都挤压在极少数 bank 上,冲突严重到让共享内存的带宽优势荡然无存。
3. 优化取模
__global__ void reduce_v1(const float* in, float* out, int N) { __shared__ float sdata[1024]; int tid = threadIdx.x; int idx = blockIdx.x * blockDim.x + threadIdx.x;
sdata[tid] = (idx < N) ? in[idx] : 0.0f; __syncthreads();
for(int s=1; s < blockDim.x; s *= 2) { int index = 2 * s * tid; if (index < blockDim.x) { sdata[index] += sdata[index + s]; } __syncthreads(); }
if (tid == 0) atomicAdd(out, sdata[0]);}3.1. 使用乘加代替取模
这一节我们来着重优化一下取模运算。
我们可以让每个线程直接算出自己要服务的起始下标index。 这个计算包含一次整数乘法和一次加法,整数乘法在GPU上的吞吐远远高于除法和取模。 更主要的是,编译器看到这种固定模式的乘法和移位,往往能够进一步优化指令组合。

3.2. 优化方案

虽然我们解决了计算的延迟,但是访存本身仍然是瓶颈。
当s=1时,我们可以看到每个Bank被两个线程访问,会发生2路冲突。 而当s=16时,我们发现访问索引为0,32,64,…。这32个线程都会在同一个Bank下访问,这就造成了32路冲突,已经是完全串行化了。
下边我们再来尝试解决一下访存冲突问题。
4. 优化访存

__global__ void reduce_v2(const float* in, float* out, int N) { __shared__ float smem[1024]; int tid = threadIdx.x; int idx = blockIdx.x * blockDim.x + tid;
smem[tid] = (idx < N) ? in[idx] : 0.0f; __syncthreads();
for (int s = BLOCK >> 1; s > 0; s >>= 1) { if (tid < s) smem[tid] += smem[tid + s]; __syncthreads(); }
if (tid == 0) atomicAdd(out, smem[0]);}可以看到,在这个地方,我彻底修改了循环方式,核心改动就在这几行中:
for (int s = BLOCK >> 1; s > 0; s >>= 1) { if (tid < s) smem[tid] += smem[tid + s]; __syncthreads();}我们要弄明白一件事,它是如何消灭Bank conflict的?
实际上,我们的步长不再是从小到大,原本是1,2,4,8,16这个样子。
而我们现在的计算方式是s = BLOCK / 2,比如BLOCK=1024,那么s一开始就是512.
- 第一轮,我们发现线程0-511干活,剩下的闲着。
- 第二轮,线程0-255干活,剩下的闲着。
- …
- 最后一轮s=1,只有线程0干活。
并且我们干活的线程是连续的,但是每个线程访存的地址却发生了变化,是连续的smem[tid]和另一段连续的smem[tid+s]。
例如这个图所示
- 线程0:地址0和地址512
- 线程1:地址1和地址513
- …
- 线程31:地址31和地址543
整个规约过程中,每一轮都减半,确保bank冲突从头到尾都是零。
5. 见证功力的优化方案
上一个版本实际上已经是相当不错的规约实现了,但是从现在开始才是我们真正优化的功力体现。
__global__ void reduce_v3(const float* in, float* out, int N) { __shared__ float smem[1024]; int tid = threadIdx.x; int idx = blockIdx.x * (blockDim.x * 2) + threadIdx.x;
float val1 = (idx < N) ? in[idx] : 0.0f; float val2 = (idx + blockDim.x < N) ? in[idx + blockDim.x] : 0.0f; smem[tid] = val1 + val2; __syncthreads();
for (int s = BLOCK >> 1; s > 0; s >>= 1) { if (tid < s) smem[tid] += smem[tid + s]; __syncthreads(); }
if (tid == 0) atomicAdd(out, smem[0]);}乍一眼看过去和v2版本规约循环一模一样,但是我们又回过头对数据加载模块进行优化。
之前我们在每个线程搬运的时候,只从全局内存搬运一个元素。 但是实际上我们可以每个线程多读取一点,例如一口气读取两个元素,并且可以直接在本地寄存器里把它们加起来。
int idx = blockIdx.x * (blockDim.x * 2) + threadIdx.x;注意这里全局下标的步长已经不是blockDim.x,而是对其乘以2。
这就意味着同样数量的线程,v3覆盖的数据是v2的两倍,或者我们启动的Block可以减半。
这直接带来了两个收益:
- 减少atomicAdd的竞争。
- 提升计算密度。寄存器的带宽远超共享内存,并且没有bank conflict,线程在等待全局数据内存的同时就可以做这件加法。
当然可以,而且很多高性能库就是这么干的。但是增加单线程的负载同时也会减少活跃的线程块数,可能降低硬件占用率。因此需要在“每线程工作量”和“并行度”之间找到一个平衡点。v3 的 ×2 只是一个讨巧的起步优化。
6. 展开最后一个Warp
优化了数据加载,我们再掉头来优化最后一次加载。
实际上,如果最后一次加载不满足32个线程,它还是性能会下降,我们实际上可以把这个计算直接展开,不要使用循环计算,直接给他打表计算(打表就是把所有的情形都列出来)。
__device__ void warpReduce_v4(volatile float* smem, int tid) { smem[tid] += smem[tid + 32]; smem[tid] += smem[tid + 16]; smem[tid] += smem[tid + 8]; smem[tid] += smem[tid + 4]; smem[tid] += smem[tid + 2]; smem[tid] += smem[tid + 1];}
__global__ void reduce_v4(const float* in, float* out, int N) { __shared__ float smem[1024]; int tid = threadIdx.x; int idx = blockIdx.x * (blockDim.x * 2) + threadIdx.x;
float val1 = (idx < N) ? in[idx] : 0.0f; float val2 = (idx + blockDim.x < N) ? in[idx + blockDim.x] : 0.0f; smem[tid] = val1 + val2; __syncthreads();
// 规约到只剩 64 个数时停手 for (int s = BLOCK >> 1; s > 32; s >>= 1) { if (tid < s) smem[tid] += smem[tid + s]; __syncthreads(); }
// 最后 64 → 1 的规约,完全在 warp 内完成 if (tid < 32) warpReduce_v4(smem, tid);
if (tid == 0) atomicAdd(out, smem[0]);}6.1. 提前停手
我们发现循环条件变了,当规约到s=64时就停止循环,此时共享内存的前64个严肃包含了64个和,那我们剩下的就完全可以手动来实现。
if (tid < 32) warpReduce_v4(smem, tid);注意这里只让前 32 个线程(即第 0 号 warp)进入函数,其他 warp 此时已经解放了,不需要再跟着做同步。
__device__ void warpReduce_v4(volatile float* smem, int tid) { smem[tid] += smem[tid + 32]; // 64 → 32 smem[tid] += smem[tid + 16]; // 32 → 16 smem[tid] += smem[tid + 8]; // 16 → 8 smem[tid] += smem[tid + 4]; // 8 → 4 smem[tid] += smem[tid + 2]; // 4 → 2 smem[tid] += smem[tid + 1]; // 2 → 1}这里每一步的步幅和之前保持一致,但是去掉了同步。
warp 内的 32 条线程在硬件上是锁步执行的,用人话来说,它们同步同一条指令。这就意味着只要让前 32 个线程一起进入这个函数,它们内部的每一步加法对外都是天然同步的,不存在写后读的风险,也不需要显式同步。
你在代码里看到的是 6 次累加,硬件上就是 6 条指令,干净利落。
6.2. volatile的作用
你有没有注意到smem的参数类型是volatile float*?
__device__ void warpReduce_v4(volatile float* smem, int tid)因为硬件保证不了编译器会老老实实地把每次读写都落到共享内存上。
编译器会做一种非常激进的优化:把频繁访问的共享内存变量缓存在寄存器里。
考虑这样一段未加 volatile 的代码:
smem[tid] += smem[tid + 32];smem[tid] += smem[tid + 16];从编译器的角度看,smem[tid] 在第一句之后已经被加载到寄存器了,第二句再用的时候,它完全可能直接用寄存器里的旧值,而不是重新去共享内存里读。但是 smem[tid] 在第一句中已经被写回了新值,而且读 smem[tid + 16] 的线程可能就在同一个 warp 内的另一个位置,它需要看到的是刚刚写进共享内存的那个最新结果。
一旦编译器把 smem[tid] 缓存到寄存器,就会发生不得鸟的事:
- 线程 0 执行完第一句,smem[0] 的新值被写回共享内存。
- 线程 0 接着执行第二句时,编译器可能直接用寄存器里的旧 smem[0] 值加上 smem[16],共享内存里的最新 smem[0] 直接被无视。
- 最终 smem[0] 得到的是一个错误的部分和。
volatile 告诉编译器每次用到这个变量时,必须从共享内存里老老实实地读写。
加上这个关键字之后,编译器不再对 smem 指向的地址做任何寄存器和缓存优化。每一步累加都从共享内存里取最新的数据,算完再立刻写回去。warp 内部的线程能保证它们看到彼此的最新写入,计算也就完全正确了。
支持与分享
如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!