FlashAttention/zh

    From Marovi AI
    This page is a translated version of the page FlashAttention and the translation is 100% complete.
    Other languages:
    Article
    Topic area Deep Learning
    Prerequisites Self-Attention, Softmax, Transformer


    概述

    FlashAttention 是一種精確且 I/O 感知 的算法,用於計算 Transformer 模型中使用的 自注意力 運算。它由 Tri Dao 等人於 2022 年提出,輸出在數值上與標準注意力實現等價,但運行速度可達數倍以上,並且內存佔用與序列長度成線性關係,而不是二次關係。其核心洞察是:在現代加速器(如 GPU)上,注意力的瓶頸在於高帶寬內存(HBM)與片上 SRAM 之間的內存帶寬,而不是浮點吞吐量。通過對計算進行分塊(tiling),使中間結果始終保留在 SRAM 中,FlashAttention 降低了 HBM 流量,並縮短了所有基於 Transformer 系統的實際運行時間——從 BERT 到大型 GPT 系列 語言模型 都受益於此。[1]

    該算法已成為生產環境中 Transformer 訓練與推理棧的事實標準組件;其後續版本(FlashAttention-2 和 FlashAttention-3)進一步帶來了更好的並行性,以及針對新一代 GPU 的硬件專用優化。

    為何標準注意力運行緩慢

    在一個 Transformer 中,縮放點積注意力的計算如下

    $ {\displaystyle \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\tfrac{Q K^{\top}}{\sqrt{d}}\right) V} $

    其中 $ Q, K, V \in \mathbb{R}^{N \times d} $ 分別是查詢、鍵和值矩陣,$ N $ 是序列長度,$ d $ 是每個頭的維度。樸素實現會先構造分數矩陣 $ S = Q K^{\top}/\sqrt{d} $,對其按行應用 softmax 得到 $ P $,再與 $ V $ 相乘。$ S $$ P $ 都是 $ N \times N $ 大小,必須寫入並讀出 HBM。

    對於長上下文,這種 $ O(N^2) $ 的數據流量主導了運行時間。在 NVIDIA A100 GPU 上,HBM 的帶寬大約比 SRAM 低一個數量級,因此注意力計算大部分周期都在等待內存。早期的「快速注意力」方法,如 稀疏注意力線性注意力,雖然降低了 FLOPs,但往往並未縮短實際運行時間,因為它們沒有解決這一內存瓶頸。

    在線 softmax 技巧

    FlashAttention 依賴於一種數值穩定的、增量式的 softmax 形式,通常稱為在線 softmax。[2] 給定一串值 $ x_1, x_2, \dots, x_N $,可以維護一個滾動最大值 $ m $ 和一個滾動歸一化因子 $ \ell $,使得在看完所有值後:

    $ {\displaystyle \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\ell}, \qquad m = \max_i x_i, \qquad \ell = \sum_i e^{x_i - m}.} $

    同樣的恆等式可推廣到注意力所需的加權和 $ \sum_i \mathrm{softmax}(x)_i v_i $:當一個新的分數塊帶着局部最大值 $ m' $ 到來時,先把之前的部分輸出按 $ e^{m_{\text{old}} - m_{\text{new}}} $ 重新縮放,再累加新的貢獻,並相應地更新歸一化因子。完整的 softmax 始終無需在內存中顯式構造。

    FlashAttention 算法

    FlashAttention 以分塊方式應用在線 softmax,在分塊矩陣上進行計算。前向傳遞的步驟如下。

    1. $ Q $ 劃分為每塊 $ B_r $ 行,將 $ K, V $ 劃分為每塊 $ B_c $ 行,塊大小要使工作集恰好能放入 SRAM。
    2. 對每個查詢塊 $ Q_i $,初始化輸出塊 $ O_i = 0 $、滾動最大值 $ m_i = -\infty $ 和歸一化因子 $ \ell_i = 0 $
    3. 遍歷鍵/值塊 $ (K_j, V_j) $。將 $ K_j $$ V_j $ 載入 SRAM,計算分數塊 $ S_{ij} = Q_i K_j^{\top}/\sqrt{d} $、局部最大值 $ m_{ij} $ 及局部指數 $ \tilde P_{ij} = e^{S_{ij} - m_{ij}} $
    4. 與累計統計量合併。更新 $ m_i^{\text{new}} = \max(m_i, m_{ij}) $,將 $ O_i $$ \ell_i $$ e^{m_i - m_i^{\text{new}}} $ 重新縮放,並累加 $ O_i \mathrel{+}= \tilde P_{ij} V_j \cdot e^{m_{ij} - m_i^{\text{new}}} $$ \ell_i \mathrel{+}= \mathbf{1}^{\top} \tilde P_{ij} \cdot e^{m_{ij} - m_i^{\text{new}}} $
    5. 處理完所有鍵/值塊後,把 $ O_i $ 除以 $ \ell_i $ 並寫回 HBM,同時保存 log-sum-exp $ L_i = m_i + \log \ell_i $(供反向傳遞使用)。

    完整的 $ N \times N $ 矩陣從不在 HBM 中顯式構造,因此內存佔用從 $ O(N^2) $ 降到 $ O(N) $。HBM 訪問次數從 $ O(N^2 d) $ 降到 $ O(N^2 d^2 / M) $,其中 $ M $ 是 SRAM 大小,加速大部分由此而來。

    反向傳遞與重計算

    標準反向傳遞需要注意力矩陣 $ P $,如果將其保存下來就會抵消 FlashAttention 的意義。FlashAttention 改為在反向傳播過程中即時重新計算 $ S $$ P $,僅使用先前保存的 log-sum-exp $ L $反向梯度 $ \mathrm{d}O $。由於在重計算階段分數塊都駐留在 SRAM 中,重計算開銷很低;在內存受限的硬件上,用更多 FLOPs 換取更少的 HBM 讀取是一筆划算的交易。整體訓練步比樸素實現更快,即使進行了更多運算。

    變體與後續版本

    FlashAttention 已發佈三個版本。

    • FlashAttention-1(2022 年)引入了上述 I/O 感知的分塊策略,並以單個融合的 CUDA 內核實現。它支持因果掩碼和 dropout,在 FP16 下按標準歸約順序保持逐位等價。
    • FlashAttention-2(2023 年)重新組織了工作劃分,將外層循環改為遍歷查詢塊、內層循環遍歷鍵/值塊,從而提高了線程塊和 warp 之間的並行度。它還減少了非矩陣乘法的 FLOPs,並在長序列上提升了利用率,在 A100 上達到理論峰值 FLOPs 的約 50–73%。[3]
    • FlashAttention-3(2024 年)面向 Hopper 架構,使用異步 Tensor Core、張量內存加速器(TMA)和 warp 專用化來重疊 softmax 與矩陣乘法工作。它還支持 FP8,並通過分塊縮放降低量化誤差。[4]

    相關思路還包括 xFormers 中的 Memory-Efficient Attention,它獨立得出了類似的分塊方法;以及 vLLM 使用的 PagedAttention,它在推理階段對 KV 緩存進行分塊,從而支持較大的批大小和大量並發序列。

    與其他快速注意力方法的比較

    PerformerLinformer 等近似方法不同,FlashAttention 計算的是精確注意力。它不改變模型的輸出或訓練動態,因此可以無需重新訓練就替換到任何現有 Transformer 中。其加速效果與改變注意力漸近代價的技術疊加生效,因為 FlashAttention 也可以與稀疏或帶狀掩碼組合使用。在實際工作負載中,只要可接受精確注意力,FlashAttention 已經取代了近似方法,因為它運行更快、內存佔用更低,且數值結果完全相同。

    局限性

    FlashAttention 的收益取決於 SRAM 與 HBM 帶寬之間的差距,以及內核經過精心手工調優的調度策略。該算法對中等到較長的序列、最高約 256 的頭維度和標準掩碼模式收益最大。自定義的注意力變體——特殊的掩碼、可學習的偏置或非標準的評分函數——可能需要新的 內核,或退回到較慢的路徑。數值結果只在內核所選的歸約順序下保持逐位等價;在 FP8 等低精度格式下,需要謹慎縮放才能保持精度

    參考文獻

    1. Template:Cite arxiv
    2. Milakov, M. and Gimelshein, N., Online Normalizer Calculation for Softmax, 2018.
    3. Template:Cite arxiv
    4. Template:Cite arxiv