Attention Rollout/zh
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Transformer, Attention Mechanism, Self-Attention |
概述
注意力 rollout(attention rollout)是一种用于量化信息如何从每个输入 token 流向每个输出位置的技术,方法是递归地相乘 Transformer 每一层的注意力权重矩阵。该方法由 Abnar 和 Zuidema 于 2020 年提出,[1]用于解决一个众所周知的问题:单个层中原始的注意力权重并不能很好地代表每个输入 token 对给定预测的贡献,因为在每一层中,token 都通过 Self-Attention 反复与其邻居混合。
Rollout 将各层堆叠视为一个图,并计算从任意输出节点回溯到输入的路径加权可达性。所得的矩阵被广泛用作基于 transformer 的模型的显著性图,特别是 Vision Transformers(ViTs)和 BERT 风格的编码器,其中它是可解释性和无需重新训练即可产生类判别式定位的标准工具之一。
动机
单层注意力权重作为解释具有直观的吸引力:位置 $ i $ 以权重 $ a_{ij} $ "查看" 位置 $ j $。然而,在深层 transformer 中,第 $ \ell $ 层的表示已经是所有输入 token 的混合,因此后期某层中 $ a_{ij} $ 较大并不意味着位置 $ i $ 的预测主要依赖于原始输入 token $ j $。多项研究认为,"注意力不是解释",正是因为这种跨层污染。[2][3]
注意力 rollout 通过提出一个不同的问题来解决该问题:如果我们沿网络向后追踪信息,那么给定输出位置的表示中有多少源自每个输入 token?答案需要将逐层的注意力矩阵组合起来,而不是孤立地检查它们。
公式化
设置
考虑一个具有 $ L $ 层、序列长度为 $ n $ 且每层 $ H $ 个注意力头的 Transformer。设 $ A^{(\ell)} \in \mathbb{R}^{n \times n} $ 表示第 $ \ell $ 层的注意力权重矩阵,其中行对应于查询位置,列对应于键位置,且每行之和为 1。当该层有多个头时,先对每个头的矩阵 $ A^{(\ell, h)} $ 求平均:
$ {\displaystyle A^{(\ell)} = \frac{1}{H} \sum_{h=1}^{H} A^{(\ell, h)}.} $
有时也使用其他归约方式,例如对各头取最大值或学得的加权和。
残差修正
纯自注意力层并不是穿过 transformer 块的唯一路径:残差连接将该层的输入加到其输出上,因此每个位置上信号中有一个不可忽略的部分实际上就是该位置自身先前的表示。为了反映这一点,rollout 将原始的注意力矩阵替换为
$ {\displaystyle \tilde{A}^{(\ell)} = \frac{1}{2}\bigl(A^{(\ell)} + I\bigr),} $
其中 $ I $ 是单位矩阵。系数 1/2 保持了行随机性:$ \tilde{A}^{(\ell)} $ 的每一行之和仍为 1,因此它可以被解释为关于输入位置的概率分布。单位项编码了残差流的贡献。
递归乘积
第 $ \ell $ 层的 rollout 矩阵是累乘
$ {\displaystyle R^{(\ell)} = \tilde{A}^{(\ell)} \, \tilde{A}^{(\ell-1)} \cdots \tilde{A}^{(1)} = \prod_{k=\ell}^{1} \tilde{A}^{(k)}.} $
条目 $ R^{(L)}_{ij} $ 被解释为输出位置 $ i $ 的表示中,通过组合的注意力路径可归因于输入位置 $ j $ 的比例。由于每个 $ \tilde{A}^{(\ell)} $ 都是行随机的,因此 $ R^{(L)} $ 也是行随机的,从而可以直接将结果可视化为输入 token 上的热力图。
对于使用特殊 CLS token 的分类 ViT 或 BERT 风格模型,$ R^{(L)} $ 中对应于 CLS 索引的那一行给出了关于输入 patch 或 token、针对预测类别的显著性图。
算法
在任何能够暴露其注意力权重的 transformer 之上,该过程都很容易实现:
- 进行一次前向传播,并为每一层和每个头保存 $ A^{(\ell, h)} $。
- 对每一层,对各头取平均以获得 $ A^{(\ell)} $。
- 加上单位矩阵并重新归一化,得到 $ \tilde{A}^{(\ell)} $。
- 初始化 $ R \leftarrow I $,并从第一层到最后一层依次执行 $ R \leftarrow \tilde{A}^{(\ell)} R $。
- 取出感兴趣的那一行(通常是 CLS),并重塑为输入网格的形状。
总开销为 $ O(L n^2) $ 次矩阵-矩阵乘积,主要由逐层的乘法主导。对于图像分类中典型的较短序列长度,整个 rollout 相对于一次前向传播的开销非常小。
注意力流
Abnar 和 Zuidema 的原始论文还提出了一种相关但成本更高的变体,称为注意力流(attention flow)。它不再相乘逐层的矩阵,而是将注意力图视为一个带容量限制的有向图,并计算从每个输入节点到每个输出节点的最大流。流更忠实地反映了网络中的瓶颈,因为相乘会重复计入共享某条边的路径,但其代价在 $ n $ 上为超线性,因此很少在大规模场景中使用。在实践中,rollout 是主流变体。
应用
注意力 rollout 已成为多种模型家族的默认可视化工具:
- Vision Transformers:将 rollout 热力图限制在 $ R^{(L)} $ 的 CLS 行,并重塑为 patch 网格后,可产生类判别式定位结果,其性能与 Grad-CAM 等基于梯度的方法相当,[4]在 ImageNet 和弱监督分割基准上具有竞争力。[5]
- 语言模型:rollout 用于检查 BERT 或 RoBERTa 编码器在给定预测中依赖哪些输入 token,作为逐头探针研究的补充。
- 多模态 transformer:在 CLIP 风格和图文模型中,限制在交叉注意力层上的 rollout 矩阵可以揭示哪些图像区域支撑某个文本 token,从而支持开放词汇分割和定位。
- 模型审计:通过比较微调步骤或领域偏移前后的 rollout 图,从业者可以检测模型是否将其依赖从某个输入区域转移到了另一个区域。
变体与扩展
基本 rollout 的若干扩展可提高其忠实度或类别特异性:
- 梯度加权 rollout:将注意力矩阵与预测类别得分对注意力值的梯度相结合,在 rollout 乘积中将 $ A^{(\ell)} $ 替换为在相关类别下评估得到的 $ (\nabla_A y)\odot A^{(\ell)} $。Chefer 等人表明,这种方法在 ViTs 中比原始 rollout 产生更清晰、更具类判别性的图。[6]
- 按头分别计算的 rollout:不再对各头取平均,而是按头分别计算 rollout 乘积并对结果进行汇总,从而揭示各个头的作用。
- 编码器-解码器 rollout:对于翻译和seq2seq模型,rollout 沿编码器自注意力链、交叉注意力层和解码器自注意力链分别计算,并将矩阵端到端组合。
- 稀疏与剪枝 rollout:在相乘之前丢弃低于阈值的零项,以使可视化聚焦于主导路径。
- Top-k 注意力 rollout:在组合之前,对每个 $ \tilde{A}^{(\ell)} $ 的每一行只保留 top-k 项,其动机是观察到注意力分布通常具有重尾。
与其他显著性方法的比较
Rollout 是应用于 transformer 的若干解释工具家族之一:
- 与原始注意力相比,rollout 修正了多层混合带来的偏差,而正是这种混合使得单层注意力具有误导性。
- 与 Grad-CAM 和其他基于梯度的方法相比,rollout 完全基于前向传播,无需对预测进行求导;这使其计算成本低且与架构无关,但若不引入梯度,类别特异性也较弱。
- 与 积分梯度 和 SHAP 风格的方法相比,rollout 速度快得多(只需少量矩阵乘积,而非大量前向传播),但不满足相同的公理保证。
- 与注意力流相比,rollout 是一种易于计算的近似,在大多数情况下产生定性相似的显著性图。
局限性
将注意力 rollout 用作解释时存在若干注意事项:
- 未使用值路径:rollout 只检查注意力权重,而忽略实际承载内容的值投影。两个具有相同注意力模式但值矩阵差异很大的头会产生无法区分的 rollout。
- 单位修正只是启发式:对 $ A $ 和 $ I $ 等权重的处理只是一种约定,而非推导得出的量。真正的残差占比取决于残差和注意力贡献的相对范数,而这在不同层和不同 token 之间会有所差异。
- 类别无关:原始 rollout 给出的图与预测类别无关;需要梯度加权变体才能得到类判别式解释。
- 忠实度争议:对基于注意力的解释进行的形式化评估结果不一,rollout 也继承了这些批评。使用者应将 rollout 图视为提供信息的摘要,而非真值归因。
- 架构假设:标准形式的 rollout 假设存在一堆相同的、带有残差连接的注意力块。将其适配到专家混合、稀疏注意力或基于路由的 transformer 时需要谨慎。
参考文献
- ↑ Template:Cite arxiv
- ↑ Jain, S. and Wallace, B. C., "Attention is not Explanation", NAACL 2019.
- ↑ Wiegreffe, S. and Pinter, Y., "Attention is not not Explanation", EMNLP 2019.
- ↑ Selvaraju, R. R. et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization", ICCV 2017.
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv