CUDA学习之路[10]:扫描算法详解

4396 字
22 分钟
CUDA学习之路[10]:扫描算法详解
已完成

1. 扫描(前缀和)算法#

1.1. 定义#

Scan 又称为Prefix Sum,是并行计算中最基础也最重要的collective原语之一。

给定一个数组[x0, x1, ..., xn-1]和一个满足结合律的二元运算符+,扫描操作计算每一个位置上的前缀聚合结果。

如果大家经常刷力扣的话,偶尔会刷到前缀和的算法。

实际上想要并行化前缀和其实并不是一个简单的事情。因为它每一项都依赖前一项的计算结果。

在扫描算法中,主要有两种算法:

  • Inclusive Scan(包含扫描)
  • Exclusive Scan(排除扫描)

区别主要在于是否要加上当前元素。

举个栗子:

输入: [1,2,3,4]
包含扫描输出:[1,3,6,10]
排除扫描输出:[0,1,3,6]

1.2. 作用#

Scan是许多高级算法的基础组件。

  • Redix Sort:用scan计算每个digit的位置偏移
  • Stream compaction:将predicate为true的算法紧凑排列
  • Histogram:累加各个bin的计数
  • SpMV:处理CSR格式的行偏移
  • 区间求和:略

上边这些看不懂?

看不懂没关系,后续我们也会陆续讲解到和scan有关的一些深度学习算子,如果有用到我们再进一步考虑。

下边我们主要看一下三种主要的并行扫描算法。

2. 并行 Scan 算法#

顺序扫描的时间复杂度是O(N),只做N-1次加法,但无法利用GPU的并行性。我们要找一种能并行化的版本,但并行化本身可能引入额外开销,需要在步骤数和总工作量之间权衡。这就引出了“工作效率”的概念:一个并行算法如果渐进地不比顺序算法做更多操作,就是工作高效的。

2.1. Hillis-Steele 算法#

算法讲解#

该算法在1986年被提出,它的性能并不算最优,我们来看一下是如何实现的。

这个算法思路非常直接,每进行一步,线程和它邻居交流的距离就翻一倍。

[1, 2, 3, 4, 5, 6]
Step1: out[i] = in[i] + in[i-1]
[1, 1+2, 2+3, 3+4, 4+5, 5+6]
[1, 3, 5, 7, 9, 11]
Step2: out[i] = out[i] + out[i-2]
[1, 3, 5+1, 7+3, 9+6, 11+10]
[1, 3, 6, 10, 15, 21]

它的加法总次数是O(nlogn),对于一个很大的数组,实际上会比串行算法多做许多加法,导致性能不佳。

Hillis-Steele算法示意图
Hillis-Steele算法示意图

下面详细拆解这个 Hillis-Steele 并行扫描(前缀和)的 CUDA 实现,从算法思路、代码执行流程,到潜在的数据冲突与补救措施,逐层分析。


我们优先实现exclusive,inclusive只需要exclusive加上原本的数组,即可实现。

观察上述实现,我们发现每一步中,线程i都需要读取线程i-offset在上一步的计算结果,同时又要把新的结果写回数组。 如果所有的线程共用一个数组,就会发生读写冲突。

因此在这个版本中,我们采用双缓冲来实现,用两个交替的数组彻底消除同一步内的读写依赖。 此外,整个 kernel 只用一个 block(线程数 = n),所有中间数据都放在片上共享内存,避免了反复读写全局内存的延迟。

先声明一个片上共享内存:

extern __shared__ float buf[]; // 动态共享内存,总大小由启动参数决定
int tid = threadIdx.x;
int pout = 0, pin = 1;
  • buf 的实际长度是在 kernel 启动时指定的:2 * n * sizeof(float)
  • 逻辑上,我们把它分成两个长度为 n 的缓冲区 buf[0..n-1]buf[n..2n-1],并通过索引偏移 pin * n / pout * n 来切换。

之后我们初始化 Exclusive Scan 的偏移

buf[pout * n + tid] = (tid > 0) ? in[tid - 1] : 0.0f;
__syncthreads();

因为要做exclusive scan,所以输出比输入整体右移移位,开头补0。 线程 0 写入 0.0f,线程 i (i>0) 写入 in[i-1]。这就得到了初始的“偏移数组”,准备开始迭代累加。 __syncthreads() 保证所有线程都写完了之后,才能进入主循环。

之后我们就开始迭代:

for (int offset = 1; offset < n; offset <<= 1) {
pout = 1 - pout; // 交换角色:上次的输出变这次的输入
pin = 1 - pout; // 这句等价于 pin = pout_old
// ① 拷贝一份到新的输出缓冲区
buf[pout * n + tid] = buf[pin * n + tid];
// ② 如果位置允许,加上距离为 offset 的元素
if (tid >= offset)
buf[pout * n + tid] += buf[pin * n + tid - offset];
__syncthreads(); // 步间同步
}
  • pout = 1 - pout; pin = 1 - pout;
    这是双缓冲的精髓。假设上一步 pout=0, pin=1,执行后 pout=1, pin=0
    意味着:上一轮的输出缓冲区 0 现在变成只读的输入,而新的输出写到缓冲区 1
    这样就彻底避免了“读和写发生在同一片内存”的危险。

  • 拷贝操作buf[pout*n+tid] = buf[pin*n+tid];
    先把输入缓冲区的当前值复制到输出缓冲区,相当于线程 tid 先继承自己上一轮的中间结果。

  • 累加前驱if (tid >= offset)
    如果线程索引不小于当前的 offset,它就需要加上距离为 offset 的那个元素(即 tid - offset 位置的值)。
    注意:这里读取的仍然是输入缓冲区 (pin),而不是正在写入的 pout
    因此即使多个线程同时执行,线程 tid 的读取源 tid - offset 不会被其他线程覆盖,没有同一步内的写后读(RAW)冲突

  • __syncthreads():在每一步结束时调用,确保所有线程都完成了对当前输出缓冲区的写入,才能安全地在下一步将其作为输入来读取。

循环结束后,pout 指向最后一步的输出缓冲区,直接将其写回 out 数组即可。

out[tid] = buf[pout * n + tid];

对于启动函数来说,我们只使用单个block,n个线程。动态内存大小刚好容纳两个长度为n的float数据。因此这个 kernel 受限于单 block 的最大线程数(通常为 1024)和共享内存大小(典型 48KB/64KB),只能处理较小规模的数组。

inline void launch_v0(const float* in, float* out, int n) {
scan_v0_naive<<<1, n, 2 * n * sizeof(float)>>>(in, out, n);
}

实际上这个方案不仅算的慢,还同时存在访存冲突和Warp Divergence。

  • 在本实现中,两个缓冲区都是连续存放的 float 数组,线程 tid 访问 buf[base + tid]
    n 是 32 的倍数时,tidtid-1(或 tid-offset)很可能映射到相同的 bank,造成 2 路甚至多路冲突。
    例如 offset=1 时,线程 0 读地址 0,线程 1 读地址 1……表面连续,但只要步长是 32 的约数,就可能同时有多个线程访问同一个 bank。(大家可以思考如何优化)
  • 对于if (tid >= offset) 会让一部分线程进入条件体,当 offset 小于 warp 大小(32)时,同一个 warp 内的线程会走不同分支,导致部分线程闲置。这是并行前缀和算法的固有特性,无法完全消除,但不会造成逻辑错误。

参考代码#

#pragma once
#include <cuda_runtime.h>
__global__ void scan_v0_naive(const float* in, float* out, int n) {
extern __shared__ float buf[];
int tid = threadIdx.x;
int pout = 0, pin = 1;
buf[pout * n + tid] = (tid > 0) ? in[tid-1] : 0.0f;
__syncthreads();
for (int offset = 1; offset < n; offset <<=1) {
pout = 1 - pout;
pin = 1 - pout;
buf[pout * n + tid] = buf[pin * n + tid];
if (tid >= offset) {
buf[pout * n + tid] += buf[pin * n + tid - offset];
}
__syncthreads();
}
out[tid] = buf[pout * n + tid];
}
inline void launch_v0(const float* in, float* out, int n) {
scan_v0_naive<<<1, n, 2 * n * sizeof(float)>>>(in, out, n);
}

2.2. Blelloch 算法#

为了解决工作效率问题,Blelloch 在1990年提出了一种基于树的工作高效算法,将操作次数降低到O(n)。

Blelloch算法示意图
Blelloch算法示意图

前缀和的本质是:

  • 位置 i 的前缀和 = 所有在它左边的元素之和。
  • 换句话说,如果把数组元素放在一棵二叉树的最底层(叶子),那么任何一段连续的叶子,其“前缀和”都可以由树中的某些部分和拼凑出来。

Blelloch 的第一个巧思就是:
先不要急着算前缀和,而是先构建一棵完整的“部分和”二叉树(归约树),再反向利用这棵树把前缀信息分发下去。

我们来看一下这棵树:

[0..7] ← 整棵树的总和
/ \
[0..3] [4..7]
/ \ / \
[0,1] [2,3] [4,5] [6,7]
/ \ / \ / \ / \
a b c d e f g h
  • 每个内部节点代表它覆盖的区间内所有叶子之和。
  • 如果我们已经拥有了这棵树上所有节点值,那么任何一个叶子的前缀和,都可以通过从根走到该叶子的路径上、某些节点的值累加得到。

这个算法把“前缀和”分解成“沿树路径累加部分和”。

第一阶段:Upsweep#

既然我们需要那棵部分和树,第一步当然就是把它算出来。
这其实就是并行归约,用规约那一套完整即可。

给定 stride offset(代表当前层相邻兄弟节点之间的距离),线程 thid 负责一对兄弟节点:

左子索引 = offset * (2*thid + 1) - 1
右子索引 = offset * (2*thid + 2) - 1

例如,当 stride=2 时(第二层,offset=2),线程 0 和 1 活跃:

  • thid=0:左子 = 2*(1)-1 = 1,右子 = 2*(2)-1 = 3 → 对应 [0,1] 和 [2,3] 的区间和
  • thid=1:左子 = 5,右子 = 7 → [4,5] 和 [6,7]

然后在每一层,只需执行:

temp[right] += temp[left]; // 右子 = 右子 + 左子

全部线程并行做完后,数组中就同时保存了叶子值和所有内部节点值,而且完全原地,没有额外内存开销。

第二阶段:Downsweep#

有了这棵树,下一个问题就是:如何利用它得到每个叶子的前缀和?

对于 exclusive scan,整个数组的“前缀和”起点是 0。所以先把根(代表全数组总和)的值清零,然后从上到下,把前缀信息传递给左右子节点。

具体规则(对应二叉树的性质):

  • 对于任何一个内部节点(父节点),它的值 parent 代表 它这个区间之前所有元素的和
  • 它的左子节点代表区间的左半部分:左半部分的前缀和,就是 parent(因为左半部分之前还是原来那些元素)。
  • 它的右子节点代表区间的右半部分:右半部分的前缀和,必须是 parent + 左半部分的总和(因为右半部分之前要加上整个左半部分)。

所以在 Downsweep 阶段,每一层的操作是:

float t = temp[left]; // 暂存左子原来的值(即左半部分的和)
temp[left] = temp[right]; // 左子获得父节点传来的前缀信息
temp[right] += t; // 右子 = 原来的右子 + 左半部分的和

注意 temp[right] 在这一层刚开始时,还保留着父节点传来的前缀值(因为父节点在上一层已经更新为正确的区间前缀)。通过这三行,前缀信息就沿着树一层层向下传播,最终到达每一个叶子

总结#

维度巧妙之处
算法复杂度总加法次数 ≈ 2n,是对数步数(log n)下的 O(n) 算法,性能接近顺序算法,却拥有巨大并行度。
树形分解将强顺序依赖的“前缀和”转化为“建树 + 分发”两个阶段,利用二叉树天然适合并行的特点。
就地操作整个计算完全在原数组上进行,不需要额外的树数据结构,索引只需简单的位运算公式。
统一框架Upsweep 和 Downsweep 的循环结构几乎镜像对称,一个自底向上,一个自顶向下,非常 GPU 友好。
可扩展性尽管这里演示的是单 block 版本,但该算法天然支持分级:先在多个 block 内做 block 内 scan,再用一个 block 做块间 scan,最后分发回各 block 修正,这就是几乎所有高性能 prefix sum 库(如 CUB、Thrust)的核心方法。

参考代码#

#pragma once
#include <cuda_runtime.h>
#include <iostream>
constexpr int NUM_BANKS = 32;
constexpr int LOG_NUM_BANKS = 5;
__device__ inline int pad(int i) {
return i + (i >> LOG_NUM_BANKS);
}
__global__ void scan_v1_blelloch(const float* in, float* out, int original_n, int n2) {
extern __shared__ float temp[];
int thid = threadIdx.x;
int offset = 1;
// 一个线程可以同时处理两个元素,提高线程本身的计算效率
int ai = thid;
int bi = thid + (n2 >> 1);
temp[pad(ai)] = (ai < original_n) ? in[ai] : 0.0f;
temp[pad(bi)] = (bi < original_n) ? in[bi] : 0.0f;
// ==================== Phase 1: Upsweep ====================
for (int d = n2 >> 1; d > 0; d >>= 1) {
// 确保上一轮都算完了
__syncthreads();
if (thid < d) {
int left_idx = offset * (2 * thid + 1) - 1;
int right_idx = offset * (2 * thid + 2) - 1;
temp[pad(right_idx)] += temp[pad(left_idx)];
}
offset <<= 1;
}
// ==================== Phase 2: Downsweep ====================
if (thid == 0) {
temp[pad(n2 - 1)] = 0.0f;
}
for (int d = 1; d < n2; d <<= 1) {
offset >>= 1;
__syncthreads();
if (thid < d) {
int left_idx = offset * (2 * thid + 1) - 1;
int right_idx = offset * (2 * thid + 2) - 1;
int pad_left = pad(left_idx);
int pad_right = pad(right_idx);
float t = temp[pad_left];
temp[pad_left] = temp[pad_right];
temp[pad_right] += t;
}
}
__syncthreads();
if (ai < original_n) out[ai] = temp[pad(ai)];
if (bi < original_n) out[bi] = temp[pad(bi)];
}
inline void launch_v1(const float* in, float* out, int n) {
if (n <= 0) return;
int n2 = 1;
while (n2 < n) n2 <<= 1;
if (n2 > 2048) {
printf("Error: Single block scan only supports up to 2048 elements.\n");
return;
}
int threads = (n2 / 2 > 0) ? (n2 / 2) : 1;
int padded_n = n2 + (n2 >> LOG_NUM_BANKS);
scan_v1_blelloch<<<1, threads, padded_n * sizeof(float)>>>(in, out, n, n2);
}

能够处理任意block的版本:

#pragma once
#include <cuda_runtime.h>
#include "scan_v1_blelloch.cuh"
constexpr int V2_BLK = 1024;
constexpr int V2_THR = V2_BLK / 2;
constexpr int PAD_V2_BLK = V2_BLK + (V2_BLK >> LOG_NUM_BANKS);
__global__ void scan_v2_tile(const float* in, float* out, float* aggs, int N) {
// 1. 加载 2 个元素到共享内存 (带 padding)
// 2. Upsweep:树状归约,算出段总和
// 3. 把总和存到 aggs[blockIdx.x],并把根节点清零
// 4. Downsweep:向下传播,得段内 exclusive scan
// 5. 写回 out
__shared__ float tmp[PAD_V2_BLK];
int tid = threadIdx.x;
// 每线程负责的两个本地索引
int li = tid;
int ri = tid + V2_BLK / 2;
// 对应的全局索引
int gi = blockIdx.x * V2_BLK + li;
int gj = blockIdx.x * V2_BLK + ri;
tmp[pad(li)] = (gi < N) ? in[gi] : 0.0f;
tmp[pad(ri)] = (gj < N) ? in[gj] : 0.0f;
int offset = 1;
// 上扫
for (int d = V2_BLK >> 1; d > 0; d >>= 1) {
__syncthreads();
if (tid < d) {
int left = offset * (2 * tid + 1) - 1;
int right = offset * (2 * tid + 2) - 1;
tmp[pad(right)] += tmp[pad(left)];
}
offset <<= 1;
}
__syncthreads();
if (tid == 0) {
aggs[blockIdx.x] = tmp[pad(V2_BLK - 1)]; // 存下 tile 总和
tmp[pad(V2_BLK - 1)] = 0.0f; // 根节点清零
}
// 下扫
for (int d = 1; d < V2_BLK; d <<= 1) {
offset >>= 1;
__syncthreads();
if (tid < d) {
int left = offset * (2 * tid + 1) - 1;
int right = offset * (2 * tid + 2) - 1;
float t = tmp[pad(left)];
tmp[pad(left)] = tmp[pad(right)];
tmp[pad(right)] += t;
}
}
__syncthreads();
if (gi < N) out[gi] = tmp[pad(li)];
if (gj < N) out[gj] = tmp[pad(ri)];
}
__global__ void scan_v2_add(const float* aggs, float* out, int N) {
// 把 aggs[blockIdx.x] 加到 out 数组上,段内所有元素都加同一个值。
int tid = threadIdx.x;
int gi = blockIdx.x * V2_BLK + tid;
int gj = gi + V2_BLK / 2;
float add = aggs[blockIdx.x]; // 该 tile 需要加上的前缀
if (gi < N) out[gi] += add;
if (gj < N) out[gj] += add;
}
inline void launch_v2_prefix(float* data, int n) {
// 终止条件:问题已缩小到一个 tile 能容纳
if (n <= V2_BLK) {
float* dummy;
cudaMalloc(&dummy, sizeof(float)); // 不需要 aggs,但 kernel 需要这个参数,给个 dummy
scan_v2_tile<<<1, V2_THR>>>(data, data, dummy, n);
cudaFree(dummy);
return;
}
// 1. 分块:计算需要多少个完整 tile
int tiles = (n + V2_BLK - 1) / V2_BLK; // 这里是变化的关键!
float* totals;
cudaMalloc(&totals, tiles * sizeof(float));
// 2. 对每个 tile 做段内扫描,同时收集每个 tile 的总和到 totals
scan_v2_tile<<<tiles, V2_THR>>>(data, data, totals, n);
// 3. 递归!对 totals 本身求前缀和
launch_v2_prefix(totals, tiles);
// 4. 将 totals 中的前缀广播回各个 tile
scan_v2_add<<<tiles, V2_THR>>>(totals, data, n);
cudaFree(totals);
}
inline void launch_v2(float* d_in, float* d_out, int n) {
// 先拷贝输入到输出数组,之后就地计算
cudaMemcpy(d_out, d_in, n * sizeof(float), cudaMemcpyDeviceToDevice);
launch_v2_prefix(d_out, n);
}
大家可以考虑如何使用shuffle来做优化?

3. Triton和Pytorch#

对于Pytorch来说,调用的是CUB;而对于Triton来说,提供了Block级别的算子,但是在Block之间,需要我们手动来实现,参考如下:

import torch
import triton
import triton.language as tl
@triton.jit
def combine(a, b): return a + b
# Pass 1: 仅计算每个 Block 的总和
@triton.jit
def reduce_kernel(x_ptr, sums_ptr, n, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offsets, mask=offsets < n, other=0.0)
tl.store(sums_ptr + pid, tl.reduce(x, 0, combine))
# Pass 2: 计算 Block 级别的全局偏移量
@triton.jit
def scan_sums_kernel(sums_ptr, offsets_ptr, num_blocks, BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < num_blocks
sums = tl.load(sums_ptr + offsets, mask=mask, other=0.0)
scanned = tl.associative_scan(sums, axis=0, combine_fn=combine)
tl.store(offsets_ptr + offsets, scanned - sums, mask=mask)
# Pass 3: 在寄存器中局部 Scan 并加上偏移量
@triton.jit
def scan_add_kernel(x_ptr, y_ptr, offsets_ptr, n, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
scanned = tl.associative_scan(x, axis=0, combine_fn=combine)
offset = tl.load(offsets_ptr + pid) # 读取此 block 的全局前缀偏移
tl.store(y_ptr + offsets, scanned + offset, mask=mask)
def scan(x: torch.Tensor) -> torch.Tensor:
n = x.numel()
y = torch.empty_like(x)
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(n, BLOCK_SIZE)
sums = torch.empty(num_blocks, device=x.device, dtype=x.dtype)
offsets = torch.empty(num_blocks, device=x.device, dtype=x.dtype)
grid = (num_blocks,)
reduce_kernel[grid](x, sums, n, BLOCK_SIZE)
scan_bs = max(16, triton.next_power_of_2(num_blocks))
scan_sums_kernel[(1,)](sums, offsets, num_blocks, scan_bs)
scan_add_kernel[grid](x, y, offsets, n, BLOCK_SIZE)
return y
if __name__ == "__main__":
x = torch.randn(10000, device='cuda')
y_triton = scan(x)
max_error = (y_triton - torch.cumsum(x, dim=0)).abs().max().item()
print(f"Max error: {max_error:.6e}\nOptimized Scan passed!")

支持与分享

如果这篇文章对你有帮助,欢迎分享给更多人或赞助支持!

赞助
CUDA学习之路[10]:扫描算法详解
https://dlog.com.cn/posts/cuda10/scan/
作者
杜子源
发布于
2026-05-13
许可协议
CC BY-NC-SA 4.0
最后更新于 2026-05-13,距今已过 44 天

部分内容可能已过时

Profile Image of the Author
杜子源
都是风景,幸会
公告
请狠狠地打赏我,打赏一次,爆更一篇!!
音乐
封面

音乐

暂未播放

0:00 0:00
暂无歌词
分类
标签
站点统计
文章
29
分类
8
标签
11
总字数
81,272
运行时长
0
最后活动
0 天前

目录