Attention Rollout/zh
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Transformer, Attention Mechanism, Self-Attention |
概述
注意力 rollout(attention rollout)是一種用於量化信息如何從每個輸入 token 流向每個輸出位置的技術,方法是遞歸地相乘 Transformer 每一層的注意力權重矩陣。該方法由 Abnar 和 Zuidema 於 2020 年提出,[1]用於解決一個眾所周知的問題:單個層中原始的注意力權重並不能很好地代表每個輸入 token 對給定預測的貢獻,因為在每一層中,token 都通過 Self-Attention 反覆與其鄰居混合。
Rollout 將各層堆疊視為一個圖,並計算從任意輸出節點回溯到輸入的路徑加權可達性。所得的矩陣被廣泛用作基於 transformer 的模型的顯著性圖,特別是 Vision Transformers(ViTs)和 BERT 風格的編碼器,其中它是可解釋性和無需重新訓練即可產生類判別式定位的標準工具之一。
動機
單層注意力權重作為解釋具有直觀的吸引力:位置 $ i $ 以權重 $ a_{ij} $ "查看" 位置 $ j $。然而,在深層 transformer 中,第 $ \ell $ 層的表示已經是所有輸入 token 的混合,因此後期某層中 $ a_{ij} $ 較大並不意味着位置 $ i $ 的預測主要依賴於原始輸入 token $ j $。多項研究認為,"注意力不是解釋",正是因為這種跨層污染。[2][3]
注意力 rollout 通過提出一個不同的問題來解決該問題:如果我們沿網絡向後追蹤信息,那麼給定輸出位置的表示中有多少源自每個輸入 token?答案需要將逐層的注意力矩陣組合起來,而不是孤立地檢查它們。
公式化
設置
考慮一個具有 $ L $ 層、序列長度為 $ n $ 且每層 $ H $ 個注意力頭的 Transformer。設 $ A^{(\ell)} \in \mathbb{R}^{n \times n} $ 表示第 $ \ell $ 層的注意力權重矩陣,其中行對應於查詢位置,列對應於鍵位置,且每行之和為 1。當該層有多個頭時,先對每個頭的矩陣 $ A^{(\ell, h)} $ 求平均:
$ {\displaystyle A^{(\ell)} = \frac{1}{H} \sum_{h=1}^{H} A^{(\ell, h)}.} $
有時也使用其他歸約方式,例如對各頭取最大值或學得的加權和。
殘差修正
純自注意力層並不是穿過 transformer 塊的唯一路徑:殘差連接將該層的輸入加到其輸出上,因此每個位置上信號中有一個不可忽略的部分實際上就是該位置自身先前的表示。為了反映這一點,rollout 將原始的注意力矩陣替換為
$ {\displaystyle \tilde{A}^{(\ell)} = \frac{1}{2}\bigl(A^{(\ell)} + I\bigr),} $
其中 $ I $ 是單位矩陣。係數 1/2 保持了行隨機性:$ \tilde{A}^{(\ell)} $ 的每一行之和仍為 1,因此它可以被解釋為關於輸入位置的概率分佈。單位項編碼了殘差流的貢獻。
遞歸乘積
第 $ \ell $ 層的 rollout 矩陣是累乘
$ {\displaystyle R^{(\ell)} = \tilde{A}^{(\ell)} \, \tilde{A}^{(\ell-1)} \cdots \tilde{A}^{(1)} = \prod_{k=\ell}^{1} \tilde{A}^{(k)}.} $
條目 $ R^{(L)}_{ij} $ 被解釋為輸出位置 $ i $ 的表示中,通過組合的注意力路徑可歸因於輸入位置 $ j $ 的比例。由於每個 $ \tilde{A}^{(\ell)} $ 都是行隨機的,因此 $ R^{(L)} $ 也是行隨機的,從而可以直接將結果可視化為輸入 token 上的熱力圖。
對於使用特殊 CLS token 的分類 ViT 或 BERT 風格模型,$ R^{(L)} $ 中對應於 CLS 索引的那一行給出了關於輸入 patch 或 token、針對預測類別的顯著性圖。
算法
在任何能夠暴露其注意力權重的 transformer 之上,該過程都很容易實現:
- 進行一次前向傳播,並為每一層和每個頭保存 $ A^{(\ell, h)} $。
- 對每一層,對各頭取平均以獲得 $ A^{(\ell)} $。
- 加上單位矩陣並重新歸一化,得到 $ \tilde{A}^{(\ell)} $。
- 初始化 $ R \leftarrow I $,並從第一層到最後一層依次執行 $ R \leftarrow \tilde{A}^{(\ell)} R $。
- 取出感興趣的那一行(通常是 CLS),並重塑為輸入網格的形狀。
總開銷為 $ O(L n^2) $ 次矩陣-矩陣乘積,主要由逐層的乘法主導。對於圖像分類中典型的較短序列長度,整個 rollout 相對於一次前向傳播的開銷非常小。
注意力流
Abnar 和 Zuidema 的原始論文還提出了一種相關但成本更高的變體,稱為注意力流(attention flow)。它不再相乘逐層的矩陣,而是將注意力圖視為一個帶容量限制的有向圖,並計算從每個輸入節點到每個輸出節點的最大流。流更忠實地反映了網絡中的瓶頸,因為相乘會重複計入共享某條邊的路徑,但其代價在 $ n $ 上為超線性,因此很少在大規模場景中使用。在實踐中,rollout 是主流變體。
應用
注意力 rollout 已成為多種模型家族的默認可視化工具:
- Vision Transformers:將 rollout 熱力圖限制在 $ R^{(L)} $ 的 CLS 行,並重塑為 patch 網格後,可產生類判別式定位結果,其性能與 Grad-CAM 等基於梯度的方法相當,[4]在 ImageNet 和弱監督分割基準上具有競爭力。[5]
- 語言模型:rollout 用於檢查 BERT 或 RoBERTa 編碼器在給定預測中依賴哪些輸入 token,作為逐頭探針研究的補充。
- 多模態 transformer:在 CLIP 風格和圖文模型中,限制在交叉注意力層上的 rollout 矩陣可以揭示哪些圖像區域支撐某個文本 token,從而支持開放詞彙分割和定位。
- 模型審計:通過比較微調步驟或領域偏移前後的 rollout 圖,從業者可以檢測模型是否將其依賴從某個輸入區域轉移到了另一個區域。
變體與擴展
基本 rollout 的若干擴展可提高其忠實度或類別特異性:
- 梯度加權 rollout:將注意力矩陣與預測類別得分對注意力值的梯度相結合,在 rollout 乘積中將 $ A^{(\ell)} $ 替換為在相關類別下評估得到的 $ (\nabla_A y)\odot A^{(\ell)} $。Chefer 等人表明,這種方法在 ViTs 中比原始 rollout 產生更清晰、更具類判別性的圖。[6]
- 按頭分別計算的 rollout:不再對各頭取平均,而是按頭分別計算 rollout 乘積並對結果進行匯總,從而揭示各個頭的作用。
- 編碼器-解碼器 rollout:對於翻譯和seq2seq模型,rollout 沿編碼器自注意力鏈、交叉注意力層和解碼器自注意力鏈分別計算,並將矩陣端到端組合。
- 稀疏與剪枝 rollout:在相乘之前丟棄低於閾值的零項,以使可視化聚焦於主導路徑。
- Top-k 注意力 rollout:在組合之前,對每個 $ \tilde{A}^{(\ell)} $ 的每一行只保留 top-k 項,其動機是觀察到注意力分佈通常具有重尾。
與其他顯著性方法的比較
Rollout 是應用於 transformer 的若干解釋工具家族之一:
- 與原始注意力相比,rollout 修正了多層混合帶來的偏差,而正是這種混合使得單層注意力具有誤導性。
- 與 Grad-CAM 和其他基於梯度的方法相比,rollout 完全基於前向傳播,無需對預測進行求導;這使其計算成本低且與架構無關,但若不引入梯度,類別特異性也較弱。
- 與 積分梯度 和 SHAP 風格的方法相比,rollout 速度快得多(只需少量矩陣乘積,而非大量前向傳播),但不滿足相同的公理保證。
- 與注意力流相比,rollout 是一種易於計算的近似,在大多數情況下產生定性相似的顯著性圖。
局限性
將注意力 rollout 用作解釋時存在若干注意事項:
- 未使用值路徑:rollout 只檢查注意力權重,而忽略實際承載內容的值投影。兩個具有相同注意力模式但值矩陣差異很大的頭會產生無法區分的 rollout。
- 單位修正只是啟發式:對 $ A $ 和 $ I $ 等權重的處理只是一種約定,而非推導得出的量。真正的殘差佔比取決於殘差和注意力貢獻的相對範數,而這在不同層和不同 token 之間會有所差異。
- 類別無關:原始 rollout 給出的圖與預測類別無關;需要梯度加權變體才能得到類判別式解釋。
- 忠實度爭議:對基於注意力的解釋進行的形式化評估結果不一,rollout 也繼承了這些批評。使用者應將 rollout 圖視為提供信息的摘要,而非真值歸因。
- 架構假設:標準形式的 rollout 假設存在一堆相同的、帶有殘差連接的注意力塊。將其適配到專家混合、稀疏注意力或基於路由的 transformer 時需要謹慎。
參考文獻
- ↑ Template:Cite arxiv
- ↑ Jain, S. and Wallace, B. C., "Attention is not Explanation", NAACL 2019.
- ↑ Wiegreffe, S. and Pinter, Y., "Attention is not not Explanation", EMNLP 2019.
- ↑ Selvaraju, R. R. et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization", ICCV 2017.
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv