Skip to content

FlashAttention

FlashAttention 是一系列通过重新设计注意力计算的内存访问模式来显著加速 Transformer 训练与推理的算法。它的核心思想很简单:减少 GPU 高速存储(HBM)与低速存储(SRAM)之间的数据传输。这一优化使得长上下文的大语言模型训练成为可能,并且推理速度提升 2-4 倍。

Definition

标准的注意力计算中,查询矩阵 Q、键矩阵 K、值矩阵 V 的计算过程为:

Attention(Q, K, V) = softmax(QK^T / √d_k) · V

传统实现的问题在于:

  1. 需要将巨大的 N×N 注意力矩阵 写入 HBM(高带宽内存)
  2. 每个操作都涉及多次 HBM 读写,而 HBM 带宽仅为 SRAM 的 1/10 ~ 1/100
  3. 当序列长度 N 增大时,内存占用按 O(N²) 增长

FlashAttention 的核心突破:不存储完整的 N×N 注意力矩阵,而是通过分块(tiling)和重算(recomputation)策略,在 SRAM 中完成全部计算。

核心机制

IO-Aware 计算

FlashAttention 的设计理念来自系统领域的 IO-Awareness——计算不仅要关注 FLOPs,更要关注数据在存储层次间的移动。

存储层次带宽容量特点
SRAM (共享存储)~19 TB/s~100 KB极快但极小
HBM (高带宽内存)~1.5 TB/s10-80 GB快但比 SRAM 慢 10-100 倍

FlashAttention 的目标是:让注意力计算尽可能在 SRAM 中完成,减少 HBM 访问。

分块(Tiling)策略

将 Q、K、V 分割成较小的块(tiles),每个块能够装入 SRAM:

  1. 将 Q 分割为 B_r 行的块,K 和 V 分割为 B_c 行的块
  2. 加载一个 Q 块和一个 K/V 块到 SRAM
  3. 在 SRAM 中计算部分注意力分数和输出
  4. 重复直到所有块处理完毕

重算(Recomputation)策略

传统方法会存储中间激活值(用于反向传播)。FlashAttention 选择不存储中间注意力矩阵,而是在反向传播时重新计算前向传播的激活值。

  • 这增加了约 20% 的 FLOPs
  • 但减少了 5-10 倍的 HBM 访问
  • 总体加速 2-4 倍

版本演进

FlashAttention-1 (2022)

Tri Dao 等人提出的原始版本:

  • 首次引入 IO-Aware 的注意力计算
  • 支持前向和反向传播
  • 在 A100 上加速 2-4 倍

FlashAttention-2 (2023)

主要改进:

  • 更好的并行度:减少非必要的 thread block 同步
  • 更高的占用率:更多 warps 可以同时运行
  • 支持头部并行(multi-head parallelism)
  • 训练速度达到 A100 理论峰值的 70-80%

FlashAttention-3 (2024)

针对 Hopper 架构(H100)优化:

  • 利用 Tensor Core 的 asynchronous execution
  • 支持 FP8 精度
  • 集成 PagedAttention(与 vLLM 兼容)
  • 在 H100 上达到 1.5-2 倍于 FlashAttention-2 的性能

对 AI 生态的影响

长上下文模型的可行性

FlashAttention 使得训练长上下文模型成为经济可行的方案:

上下文长度传统注意力内存FlashAttention 内存差异
2K16 MB4 MB
8K256 MB16 MB16×
32K4 GB64 MB64×
128K64 GB256 MB256×

这直接支持了:

推理加速

在推理阶段,FlashAttention 与 KV Cache & Prompt Caching 结合使用:

  • 减少自回归生成中的内存开销
  • 支持更大的批处理维度
  • vLLM 等推理框架的核心组件

与其他技术的关系

技术关系
KV CacheFlashAttention 减少了 KV Cache 的内存占用,使更长上下文成为可能
PagedAttentionvLLM 将 FlashAttention 与分页式内存管理结合,支持更大的并发
Speculative DecodingFlashAttention 的高效推理为草稿模型提供了基础
Ring Attention将 FlashAttention 扩展到多机分布式训练,支持更长上下文

实际应用

  • 训练框架: PyTorch 2.0+ 内置 torch.nn.functional.scaled_dot_product_attention,默认使用 FlashAttention
  • 推理框架: vLLM、TensorRT-LLM、TGI 均集成了 FlashAttention
  • 主流模型: Llama 2/3、GPT-4、Claude 等的训练和推理均依赖 FlashAttention

Synthesis

FlashAttention 是 Transformer 时代最重要的工程优化之一。它证明了一个关键原理:在现代 GPU 架构下,内存访问模式比纯粹的计算量更重要。这一认识影响了后续所有大规模模型的设计——从更长的上下文窗口到更高的推理并发,FlashAttention 都是不可或缺的基础。

Sources

  • Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS.
  • Dao (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning."
  • Shah et al. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision."
  • FlashAttention GitHub: https://github.com/Dao-AILab/flash-attention

AI Knowledge Base — 持续积累