FlashAttention/zh

    From Marovi AI
    This page is a translated version of the page FlashAttention and the translation is 100% complete.
    Other languages:
    Article
    Topic area Deep Learning
    Prerequisites Self-Attention, Softmax, Transformer


    概述

    FlashAttention 是一种精确且 I/O 感知 的算法,用于计算 Transformer 模型中使用的 自注意力 运算。它由 Tri Dao 等人于 2022 年提出,输出在数值上与标准注意力实现等价,但运行速度可达数倍以上,并且内存占用与序列长度成线性关系,而不是二次关系。其核心洞察是:在现代加速器(如 GPU)上,注意力的瓶颈在于高带宽内存(HBM)与片上 SRAM 之间的内存带宽,而不是浮点吞吐量。通过对计算进行分块(tiling),使中间结果始终保留在 SRAM 中,FlashAttention 降低了 HBM 流量,并缩短了所有基于 Transformer 系统的实际运行时间——从 BERT 到大型 GPT 系列 语言模型 都受益于此。[1]

    该算法已成为生产环境中 Transformer 训练与推理栈的事实标准组件;其后续版本(FlashAttention-2 和 FlashAttention-3)进一步带来了更好的并行性,以及针对新一代 GPU 的硬件专用优化。

    为何标准注意力运行缓慢

    在一个 Transformer 中,缩放点积注意力的计算如下

    $ {\displaystyle \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\tfrac{Q K^{\top}}{\sqrt{d}}\right) V} $

    其中 $ Q, K, V \in \mathbb{R}^{N \times d} $ 分别是查询、键和值矩阵,$ N $ 是序列长度,$ d $ 是每个头的维度。朴素实现会先构造分数矩阵 $ S = Q K^{\top}/\sqrt{d} $,对其按行应用 softmax 得到 $ P $,再与 $ V $ 相乘。$ S $$ P $ 都是 $ N \times N $ 大小,必须写入并读出 HBM。

    对于长上下文,这种 $ O(N^2) $ 的数据流量主导了运行时间。在 NVIDIA A100 GPU 上,HBM 的带宽大约比 SRAM 低一个数量级,因此注意力计算大部分周期都在等待内存。早期的“快速注意力”方法,如 稀疏注意力线性注意力,虽然降低了 FLOPs,但往往并未缩短实际运行时间,因为它们没有解决这一内存瓶颈。

    在线 softmax 技巧

    FlashAttention 依赖于一种数值稳定的、增量式的 softmax 形式,通常称为在线 softmax。[2] 给定一串值 $ x_1, x_2, \dots, x_N $,可以维护一个滚动最大值 $ m $ 和一个滚动归一化因子 $ \ell $,使得在看完所有值后:

    $ {\displaystyle \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\ell}, \qquad m = \max_i x_i, \qquad \ell = \sum_i e^{x_i - m}.} $

    同样的恒等式可推广到注意力所需的加权和 $ \sum_i \mathrm{softmax}(x)_i v_i $:当一个新的分数块带着局部最大值 $ m' $ 到来时,先把之前的部分输出按 $ e^{m_{\text{old}} - m_{\text{new}}} $ 重新缩放,再累加新的贡献,并相应地更新归一化因子。完整的 softmax 始终无需在内存中显式构造。

    FlashAttention 算法

    FlashAttention 以分块方式应用在线 softmax,在分块矩阵上进行计算。前向传递的步骤如下。

    1. $ Q $ 划分为每块 $ B_r $ 行,将 $ K, V $ 划分为每块 $ B_c $ 行,块大小要使工作集恰好能放入 SRAM。
    2. 对每个查询块 $ Q_i $,初始化输出块 $ O_i = 0 $、滚动最大值 $ m_i = -\infty $ 和归一化因子 $ \ell_i = 0 $
    3. 遍历键/值块 $ (K_j, V_j) $。将 $ K_j $$ V_j $ 载入 SRAM,计算分数块 $ S_{ij} = Q_i K_j^{\top}/\sqrt{d} $、局部最大值 $ m_{ij} $ 及局部指数 $ \tilde P_{ij} = e^{S_{ij} - m_{ij}} $
    4. 与累计统计量合并。更新 $ m_i^{\text{new}} = \max(m_i, m_{ij}) $,将 $ O_i $$ \ell_i $$ e^{m_i - m_i^{\text{new}}} $ 重新缩放,并累加 $ O_i \mathrel{+}= \tilde P_{ij} V_j \cdot e^{m_{ij} - m_i^{\text{new}}} $$ \ell_i \mathrel{+}= \mathbf{1}^{\top} \tilde P_{ij} \cdot e^{m_{ij} - m_i^{\text{new}}} $
    5. 处理完所有键/值块后,把 $ O_i $ 除以 $ \ell_i $ 并写回 HBM,同时保存 log-sum-exp $ L_i = m_i + \log \ell_i $(供反向传递使用)。

    完整的 $ N \times N $ 矩阵从不在 HBM 中显式构造,因此内存占用从 $ O(N^2) $ 降到 $ O(N) $。HBM 访问次数从 $ O(N^2 d) $ 降到 $ O(N^2 d^2 / M) $,其中 $ M $ 是 SRAM 大小,加速大部分由此而来。

    反向传递与重计算

    标准反向传递需要注意力矩阵 $ P $,如果将其保存下来就会抵消 FlashAttention 的意义。FlashAttention 改为在反向传播过程中即时重新计算 $ S $$ P $,仅使用先前保存的 log-sum-exp $ L $反向梯度 $ \mathrm{d}O $。由于在重计算阶段分数块都驻留在 SRAM 中,重计算开销很低;在内存受限的硬件上,用更多 FLOPs 换取更少的 HBM 读取是一笔划算的交易。整体训练步比朴素实现更快,即使进行了更多运算。

    变体与后续版本

    FlashAttention 已发布三个版本。

    • FlashAttention-1(2022 年)引入了上述 I/O 感知的分块策略,并以单个融合的 CUDA 内核实现。它支持因果掩码和 dropout,在 FP16 下按标准归约顺序保持逐位等价。
    • FlashAttention-2(2023 年)重新组织了工作划分,将外层循环改为遍历查询块、内层循环遍历键/值块,从而提高了线程块和 warp 之间的并行度。它还减少了非矩阵乘法的 FLOPs,并在长序列上提升了利用率,在 A100 上达到理论峰值 FLOPs 的约 50–73%。[3]
    • FlashAttention-3(2024 年)面向 Hopper 架构,使用异步 Tensor Core、张量内存加速器(TMA)和 warp 专用化来重叠 softmax 与矩阵乘法工作。它还支持 FP8,并通过分块缩放降低量化误差。[4]

    相关思路还包括 xFormers 中的 Memory-Efficient Attention,它独立得出了类似的分块方法;以及 vLLM 使用的 PagedAttention,它在推理阶段对 KV 缓存进行分块,从而支持较大的批大小和大量并发序列。

    与其他快速注意力方法的比较

    PerformerLinformer 等近似方法不同,FlashAttention 计算的是精确注意力。它不改变模型的输出或训练动态,因此可以无需重新训练就替换到任何现有 Transformer 中。其加速效果与改变注意力渐近代价的技术叠加生效,因为 FlashAttention 也可以与稀疏或带状掩码组合使用。在实际工作负载中,只要可接受精确注意力,FlashAttention 已经取代了近似方法,因为它运行更快、内存占用更低,且数值结果完全相同。

    局限性

    FlashAttention 的收益取决于 SRAM 与 HBM 带宽之间的差距,以及内核经过精心手工调优的调度策略。该算法对中等到较长的序列、最高约 256 的头维度和标准掩码模式收益最大。自定义的注意力变体——特殊的掩码、可学习的偏置或非标准的评分函数——可能需要新的 内核,或退回到较慢的路径。数值结果只在内核所选的归约顺序下保持逐位等价;在 FP8 等低精度格式下,需要谨慎缩放才能保持精度

    参考文献

    1. Template:Cite arxiv
    2. Milakov, M. and Gimelshein, N., Online Normalizer Calculation for Softmax, 2018.
    3. Template:Cite arxiv
    4. Template:Cite arxiv