CUDA学习之路[10]:扫描算法详解
1. 扫描(前缀和)算法
1.1. 定义
Scan 又称为Prefix Sum,是并行计算中最基础也最重要的collective原语之一。
给定一个数组[x0, x1, ..., xn-1]和一个满足结合律的二元运算符+,扫描操作计算每一个位置上的前缀聚合结果。
如果大家经常刷力扣的话,偶尔会刷到前缀和的算法。
实际上想要并行化前缀和其实并不是一个简单的事情。因为它每一项都依赖前一项的计算结果。
在扫描算法中,主要有两种算法:
- Inclusive Scan(包含扫描)
- Exclusive Scan(排除扫描)
区别主要在于是否要加上当前元素。
举个栗子:
输入: [1,2,3,4]包含扫描输出:[1,3,6,10]排除扫描输出:[0,1,3,6]1.2. 作用
Scan是许多高级算法的基础组件。
- Redix Sort:用scan计算每个digit的位置偏移
- Stream compaction:将predicate为true的算法紧凑排列
- Histogram:累加各个bin的计数
- SpMV:处理CSR格式的行偏移
- 区间求和:略
上边这些看不懂?
看不懂没关系,后续我们也会陆续讲解到和scan有关的一些深度学习算子,如果有用到我们再进一步考虑。
下边我们主要看一下三种主要的并行扫描算法。
2. 并行 Scan 算法
顺序扫描的时间复杂度是O(N),只做N-1次加法,但无法利用GPU的并行性。我们要找一种能并行化的版本,但并行化本身可能引入额外开销,需要在步骤数和总工作量之间权衡。这就引出了“工作效率”的概念:一个并行算法如果渐进地不比顺序算法做更多操作,就是工作高效的。
2.1. Hillis-Steele 算法
算法讲解
该算法在1986年被提出,它的性能并不算最优,我们来看一下是如何实现的。
这个算法思路非常直接,每进行一步,线程和它邻居交流的距离就翻一倍。
[1, 2, 3, 4, 5, 6]Step1: out[i] = in[i] + in[i-1] [1, 1+2, 2+3, 3+4, 4+5, 5+6] [1, 3, 5, 7, 9, 11]Step2: out[i] = out[i] + out[i-2] [1, 3, 5+1, 7+3, 9+6, 11+10] [1, 3, 6, 10, 15, 21]它的加法总次数是O(nlogn),对于一个很大的数组,实际上会比串行算法多做许多加法,导致性能不佳。

下面详细拆解这个 Hillis-Steele 并行扫描(前缀和)的 CUDA 实现,从算法思路、代码执行流程,到潜在的数据冲突与补救措施,逐层分析。
我们优先实现exclusive,inclusive只需要exclusive加上原本的数组,即可实现。
观察上述实现,我们发现每一步中,线程i都需要读取线程i-offset在上一步的计算结果,同时又要把新的结果写回数组。
如果所有的线程共用一个数组,就会发生读写冲突。
因此在这个版本中,我们采用双缓冲来实现,用两个交替的数组彻底消除同一步内的读写依赖。 此外,整个 kernel 只用一个 block(线程数 = n),所有中间数据都放在片上共享内存,避免了反复读写全局内存的延迟。
先声明一个片上共享内存:
extern __shared__ float buf[]; // 动态共享内存,总大小由启动参数决定int tid = threadIdx.x;int pout = 0, pin = 1;buf的实际长度是在 kernel 启动时指定的:2 * n * sizeof(float)。- 逻辑上,我们把它分成两个长度为
n的缓冲区buf[0..n-1]和buf[n..2n-1],并通过索引偏移pin * n/pout * n来切换。
之后我们初始化 Exclusive Scan 的偏移
buf[pout * n + tid] = (tid > 0) ? in[tid - 1] : 0.0f;__syncthreads();因为要做exclusive scan,所以输出比输入整体右移移位,开头补0。
线程 0 写入 0.0f,线程 i (i>0) 写入 in[i-1]。这就得到了初始的“偏移数组”,准备开始迭代累加。
__syncthreads() 保证所有线程都写完了之后,才能进入主循环。
之后我们就开始迭代:
for (int offset = 1; offset < n; offset <<= 1) { pout = 1 - pout; // 交换角色:上次的输出变这次的输入 pin = 1 - pout; // 这句等价于 pin = pout_old
// ① 拷贝一份到新的输出缓冲区 buf[pout * n + tid] = buf[pin * n + tid];
// ② 如果位置允许,加上距离为 offset 的元素 if (tid >= offset) buf[pout * n + tid] += buf[pin * n + tid - offset];
__syncthreads(); // 步间同步}-
pout = 1 - pout; pin = 1 - pout;
这是双缓冲的精髓。假设上一步pout=0, pin=1,执行后pout=1, pin=0。
意味着:上一轮的输出缓冲区0现在变成只读的输入,而新的输出写到缓冲区1。
这样就彻底避免了“读和写发生在同一片内存”的危险。 -
拷贝操作:
buf[pout*n+tid] = buf[pin*n+tid];
先把输入缓冲区的当前值复制到输出缓冲区,相当于线程tid先继承自己上一轮的中间结果。 -
累加前驱:
if (tid >= offset)
如果线程索引不小于当前的 offset,它就需要加上距离为offset的那个元素(即tid - offset位置的值)。
注意:这里读取的仍然是输入缓冲区 (pin),而不是正在写入的pout。
因此即使多个线程同时执行,线程tid的读取源tid - offset不会被其他线程覆盖,没有同一步内的写后读(RAW)冲突。 -
__syncthreads():在每一步结束时调用,确保所有线程都完成了对当前输出缓冲区的写入,才能安全地在下一步将其作为输入来读取。
循环结束后,pout 指向最后一步的输出缓冲区,直接将其写回 out 数组即可。
out[tid] = buf[pout * n + tid];对于启动函数来说,我们只使用单个block,n个线程。动态内存大小刚好容纳两个长度为n的float数据。因此这个 kernel 受限于单 block 的最大线程数(通常为 1024)和共享内存大小(典型 48KB/64KB),只能处理较小规模的数组。
inline void launch_v0(const float* in, float* out, int n) { scan_v0_naive<<<1, n, 2 * n * sizeof(float)>>>(in, out, n);}实际上这个方案不仅算的慢,还同时存在访存冲突和Warp Divergence。
- 在本实现中,两个缓冲区都是连续存放的
float数组,线程tid访问buf[base + tid]。
当n是 32 的倍数时,tid和tid-1(或tid-offset)很可能映射到相同的 bank,造成 2 路甚至多路冲突。
例如 offset=1 时,线程 0 读地址 0,线程 1 读地址 1……表面连续,但只要步长是 32 的约数,就可能同时有多个线程访问同一个 bank。(大家可以思考如何优化) - 对于
if (tid >= offset)会让一部分线程进入条件体,当offset小于 warp 大小(32)时,同一个 warp 内的线程会走不同分支,导致部分线程闲置。这是并行前缀和算法的固有特性,无法完全消除,但不会造成逻辑错误。
参考代码
#pragma once#include <cuda_runtime.h>
__global__ void scan_v0_naive(const float* in, float* out, int n) { extern __shared__ float buf[];
int tid = threadIdx.x; int pout = 0, pin = 1;
buf[pout * n + tid] = (tid > 0) ? in[tid-1] : 0.0f; __syncthreads();
for (int offset = 1; offset < n; offset <<=1) { pout = 1 - pout; pin = 1 - pout; buf[pout * n + tid] = buf[pin * n + tid];
if (tid >= offset) { buf[pout * n + tid] += buf[pin * n + tid - offset]; } __syncthreads(); } out[tid] = buf[pout * n + tid];}
inline void launch_v0(const float* in, float* out, int n) { scan_v0_naive<<<1, n, 2 * n * sizeof(float)>>>(in, out, n);}2.2. Blelloch 算法
为了解决工作效率问题,Blelloch 在1990年提出了一种基于树的工作高效算法,将操作次数降低到O(n)。

前缀和的本质是:
- 位置 i 的前缀和 = 所有在它左边的元素之和。
- 换句话说,如果把数组元素放在一棵二叉树的最底层(叶子),那么任何一段连续的叶子,其“前缀和”都可以由树中的某些部分和拼凑出来。
Blelloch 的第一个巧思就是:
先不要急着算前缀和,而是先构建一棵完整的“部分和”二叉树(归约树),再反向利用这棵树把前缀信息分发下去。
我们来看一下这棵树:
[0..7] ← 整棵树的总和 / \ [0..3] [4..7] / \ / \ [0,1] [2,3] [4,5] [6,7] / \ / \ / \ / \ a b c d e f g h- 每个内部节点代表它覆盖的区间内所有叶子之和。
- 如果我们已经拥有了这棵树上所有节点值,那么任何一个叶子的前缀和,都可以通过从根走到该叶子的路径上、某些节点的值累加得到。
这个算法把“前缀和”分解成“沿树路径累加部分和”。
第一阶段:Upsweep
既然我们需要那棵部分和树,第一步当然就是把它算出来。
这其实就是并行归约,用规约那一套完整即可。
给定 stride offset(代表当前层相邻兄弟节点之间的距离),线程 thid 负责一对兄弟节点:
左子索引 = offset * (2*thid + 1) - 1右子索引 = offset * (2*thid + 2) - 1例如,当 stride=2 时(第二层,offset=2),线程 0 和 1 活跃:
- thid=0:左子 = 2*(1)-1 = 1,右子 = 2*(2)-1 = 3 → 对应 [0,1] 和 [2,3] 的区间和
- thid=1:左子 = 5,右子 = 7 → [4,5] 和 [6,7]
然后在每一层,只需执行:
temp[right] += temp[left]; // 右子 = 右子 + 左子全部线程并行做完后,数组中就同时保存了叶子值和所有内部节点值,而且完全原地,没有额外内存开销。
第二阶段:Downsweep
有了这棵树,下一个问题就是:如何利用它得到每个叶子的前缀和?
对于 exclusive scan,整个数组的“前缀和”起点是 0。所以先把根(代表全数组总和)的值清零,然后从上到下,把前缀信息传递给左右子节点。
具体规则(对应二叉树的性质):
- 对于任何一个内部节点(父节点),它的值
parent代表 它这个区间之前所有元素的和。 - 它的左子节点代表区间的左半部分:左半部分的前缀和,就是
parent(因为左半部分之前还是原来那些元素)。 - 它的右子节点代表区间的右半部分:右半部分的前缀和,必须是
parent + 左半部分的总和(因为右半部分之前要加上整个左半部分)。
所以在 Downsweep 阶段,每一层的操作是:
float t = temp[left]; // 暂存左子原来的值(即左半部分的和)temp[left] = temp[right]; // 左子获得父节点传来的前缀信息temp[right] += t; // 右子 = 原来的右子 + 左半部分的和注意 temp[right] 在这一层刚开始时,还保留着父节点传来的前缀值(因为父节点在上一层已经更新为正确的区间前缀)。通过这三行,前缀信息就沿着树一层层向下传播,最终到达每一个叶子。
总结
| 维度 | 巧妙之处 |
|---|---|
| 算法复杂度 | 总加法次数 ≈ 2n,是对数步数(log n)下的 O(n) 算法,性能接近顺序算法,却拥有巨大并行度。 |
| 树形分解 | 将强顺序依赖的“前缀和”转化为“建树 + 分发”两个阶段,利用二叉树天然适合并行的特点。 |
| 就地操作 | 整个计算完全在原数组上进行,不需要额外的树数据结构,索引只需简单的位运算公式。 |
| 统一框架 | Upsweep 和 Downsweep 的循环结构几乎镜像对称,一个自底向上,一个自顶向下,非常 GPU 友好。 |
| 可扩展性 | 尽管这里演示的是单 block 版本,但该算法天然支持分级:先在多个 block 内做 block 内 scan,再用一个 block 做块间 scan,最后分发回各 block 修正,这就是几乎所有高性能 prefix sum 库(如 CUB、Thrust)的核心方法。 |
参考代码
#pragma once#include <cuda_runtime.h>#include <iostream>
constexpr int NUM_BANKS = 32;constexpr int LOG_NUM_BANKS = 5;
__device__ inline int pad(int i) { return i + (i >> LOG_NUM_BANKS);}
__global__ void scan_v1_blelloch(const float* in, float* out, int original_n, int n2) { extern __shared__ float temp[];
int thid = threadIdx.x; int offset = 1;
// 一个线程可以同时处理两个元素,提高线程本身的计算效率 int ai = thid; int bi = thid + (n2 >> 1);
temp[pad(ai)] = (ai < original_n) ? in[ai] : 0.0f; temp[pad(bi)] = (bi < original_n) ? in[bi] : 0.0f;
// ==================== Phase 1: Upsweep ==================== for (int d = n2 >> 1; d > 0; d >>= 1) { // 确保上一轮都算完了 __syncthreads(); if (thid < d) { int left_idx = offset * (2 * thid + 1) - 1; int right_idx = offset * (2 * thid + 2) - 1;
temp[pad(right_idx)] += temp[pad(left_idx)]; } offset <<= 1; }
// ==================== Phase 2: Downsweep ==================== if (thid == 0) { temp[pad(n2 - 1)] = 0.0f; }
for (int d = 1; d < n2; d <<= 1) { offset >>= 1; __syncthreads(); if (thid < d) { int left_idx = offset * (2 * thid + 1) - 1; int right_idx = offset * (2 * thid + 2) - 1;
int pad_left = pad(left_idx); int pad_right = pad(right_idx);
float t = temp[pad_left]; temp[pad_left] = temp[pad_right]; temp[pad_right] += t; } } __syncthreads();
if (ai < original_n) out[ai] = temp[pad(ai)]; if (bi < original_n) out[bi] = temp[pad(bi)];}
inline void launch_v1(const float* in, float* out, int n) { if (n <= 0) return;
int n2 = 1; while (n2 < n) n2 <<= 1;
if (n2 > 2048) { printf("Error: Single block scan only supports up to 2048 elements.\n"); return; }
int threads = (n2 / 2 > 0) ? (n2 / 2) : 1;
int padded_n = n2 + (n2 >> LOG_NUM_BANKS);
scan_v1_blelloch<<<1, threads, padded_n * sizeof(float)>>>(in, out, n, n2);}能够处理任意block的版本:
#pragma once#include <cuda_runtime.h>#include "scan_v1_blelloch.cuh"
constexpr int V2_BLK = 1024;constexpr int V2_THR = V2_BLK / 2;constexpr int PAD_V2_BLK = V2_BLK + (V2_BLK >> LOG_NUM_BANKS);
__global__ void scan_v2_tile(const float* in, float* out, float* aggs, int N) { // 1. 加载 2 个元素到共享内存 (带 padding) // 2. Upsweep:树状归约,算出段总和 // 3. 把总和存到 aggs[blockIdx.x],并把根节点清零 // 4. Downsweep:向下传播,得段内 exclusive scan // 5. 写回 out __shared__ float tmp[PAD_V2_BLK];
int tid = threadIdx.x;
// 每线程负责的两个本地索引 int li = tid; int ri = tid + V2_BLK / 2;
// 对应的全局索引 int gi = blockIdx.x * V2_BLK + li; int gj = blockIdx.x * V2_BLK + ri;
tmp[pad(li)] = (gi < N) ? in[gi] : 0.0f; tmp[pad(ri)] = (gj < N) ? in[gj] : 0.0f;
int offset = 1; // 上扫 for (int d = V2_BLK >> 1; d > 0; d >>= 1) { __syncthreads(); if (tid < d) { int left = offset * (2 * tid + 1) - 1; int right = offset * (2 * tid + 2) - 1; tmp[pad(right)] += tmp[pad(left)]; } offset <<= 1; }
__syncthreads(); if (tid == 0) { aggs[blockIdx.x] = tmp[pad(V2_BLK - 1)]; // 存下 tile 总和 tmp[pad(V2_BLK - 1)] = 0.0f; // 根节点清零 }
// 下扫 for (int d = 1; d < V2_BLK; d <<= 1) { offset >>= 1; __syncthreads(); if (tid < d) { int left = offset * (2 * tid + 1) - 1; int right = offset * (2 * tid + 2) - 1;
float t = tmp[pad(left)]; tmp[pad(left)] = tmp[pad(right)]; tmp[pad(right)] += t; } } __syncthreads();
if (gi < N) out[gi] = tmp[pad(li)]; if (gj < N) out[gj] = tmp[pad(ri)];}
__global__ void scan_v2_add(const float* aggs, float* out, int N) { // 把 aggs[blockIdx.x] 加到 out 数组上,段内所有元素都加同一个值。 int tid = threadIdx.x; int gi = blockIdx.x * V2_BLK + tid; int gj = gi + V2_BLK / 2; float add = aggs[blockIdx.x]; // 该 tile 需要加上的前缀
if (gi < N) out[gi] += add; if (gj < N) out[gj] += add;}
inline void launch_v2_prefix(float* data, int n) { // 终止条件:问题已缩小到一个 tile 能容纳 if (n <= V2_BLK) { float* dummy; cudaMalloc(&dummy, sizeof(float)); // 不需要 aggs,但 kernel 需要这个参数,给个 dummy scan_v2_tile<<<1, V2_THR>>>(data, data, dummy, n); cudaFree(dummy); return; }
// 1. 分块:计算需要多少个完整 tile int tiles = (n + V2_BLK - 1) / V2_BLK; // 这里是变化的关键! float* totals; cudaMalloc(&totals, tiles * sizeof(float));
// 2. 对每个 tile 做段内扫描,同时收集每个 tile 的总和到 totals scan_v2_tile<<<tiles, V2_THR>>>(data, data, totals, n); // 3. 递归!对 totals 本身求前缀和 launch_v2_prefix(totals, tiles); // 4. 将 totals 中的前缀广播回各个 tile scan_v2_add<<<tiles, V2_THR>>>(totals, data, n);
cudaFree(totals);}
inline void launch_v2(float* d_in, float* d_out, int n) { // 先拷贝输入到输出数组,之后就地计算 cudaMemcpy(d_out, d_in, n * sizeof(float), cudaMemcpyDeviceToDevice); launch_v2_prefix(d_out, n);}3. Triton和Pytorch
对于Pytorch来说,调用的是CUB;而对于Triton来说,提供了Block级别的算子,但是在Block之间,需要我们手动来实现,参考如下:
import torchimport tritonimport triton.language as tl
@triton.jitdef combine(a, b): return a + b
# Pass 1: 仅计算每个 Block 的总和@triton.jitdef reduce_kernel(x_ptr, sums_ptr, n, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offsets, mask=offsets < n, other=0.0) tl.store(sums_ptr + pid, tl.reduce(x, 0, combine))
# Pass 2: 计算 Block 级别的全局偏移量@triton.jitdef scan_sums_kernel(sums_ptr, offsets_ptr, num_blocks, BLOCK_SIZE: tl.constexpr): offsets = tl.arange(0, BLOCK_SIZE) mask = offsets < num_blocks sums = tl.load(sums_ptr + offsets, mask=mask, other=0.0) scanned = tl.associative_scan(sums, axis=0, combine_fn=combine) tl.store(offsets_ptr + offsets, scanned - sums, mask=mask)
# Pass 3: 在寄存器中局部 Scan 并加上偏移量@triton.jitdef scan_add_kernel(x_ptr, y_ptr, offsets_ptr, n, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n x = tl.load(x_ptr + offsets, mask=mask, other=0.0) scanned = tl.associative_scan(x, axis=0, combine_fn=combine) offset = tl.load(offsets_ptr + pid) # 读取此 block 的全局前缀偏移
tl.store(y_ptr + offsets, scanned + offset, mask=mask)
def scan(x: torch.Tensor) -> torch.Tensor: n = x.numel() y = torch.empty_like(x) BLOCK_SIZE = 1024 num_blocks = triton.cdiv(n, BLOCK_SIZE)
sums = torch.empty(num_blocks, device=x.device, dtype=x.dtype) offsets = torch.empty(num_blocks, device=x.device, dtype=x.dtype)
grid = (num_blocks,) reduce_kernel[grid](x, sums, n, BLOCK_SIZE)
scan_bs = max(16, triton.next_power_of_2(num_blocks)) scan_sums_kernel[(1,)](sums, offsets, num_blocks, scan_bs)
scan_add_kernel[grid](x, y, offsets, n, BLOCK_SIZE) return y
if __name__ == "__main__": x = torch.randn(10000, device='cuda') y_triton = scan(x) max_error = (y_triton - torch.cumsum(x, dim=0)).abs().max().item() print(f"Max error: {max_error:.6e}\nOptimized Scan passed!")支持与分享
如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!
部分内容可能已过时