Cross-Attention/zh
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Attention Mechanism, Self-Attention, Transformer |
概述
交叉注意力是注意力机制的一种变体,其中查询来自一个序列,而键和值来自另一个序列。它是让模型根据另一个流的内容来调控某个流的生成或表示的标准机制,也是编码器-解码器型Transformer、检索增强模型以及大多数将文本与图像、音频或视频绑定的现代多模态系统的架构骨干。
与自注意力不同——后者中查询、键和值都是同一输入的投影——交叉注意力建立了一种从源序列到目标序列的非对称信息流。目标序列提出问题;源序列提供证据。这种解耦使得交叉注意力在两个流的长度不同、模态不同或角色不同时尤为自然契合:例如,在机器翻译中解码器的词元关注编码器的输出,或者在视觉-语言模型中文本词元关注图像块。
交叉注意力作为原始Transformer架构的一部分被提出[1],此后已成为在扩散模型[2]、Perceiver 风格架构[3]以及当代多模态大型语言模型中被反复使用的基本构件。
直觉
一个有用的思维模型是把注意力看作在一个关联记忆中进行的软性、可微分的查找。自注意力是一种其内容就是执行查找的词元本身的记忆;交叉注意力则是一种内容来自别处的记忆。翻译模型的解码器在生成下一个德语单词时,会查询编码器对源英语句子的表示,以决定此刻应该关注哪些词。查询知道自己想要什么("一个表示主语的名词短语");键宣告它们能提供什么("我是一个关于猫的名词短语");值则提供实际混合到解码器隐藏状态中的内容。
由于在一个解码步内键和值来自一个固定的外部源,交叉注意力也是将条件信号注入生成模型的最自然的位置。例如,文本到图像的扩散模型将去噪U-Net的空间特征作为查询,将编码后的文本提示作为键和值,使得每个空间位置都能在每个去噪步中有选择地从提示中提取语义内容。
公式化
设目标(查询)序列长度为$ n $、隐藏维度为$ d $,即$ X_{\text{tgt}} \in \mathbb{R}^{n \times d} $;设源(键/值)序列长度为$ m $,即$ X_{\text{src}} \in \mathbb{R}^{m \times d} $。三个可学习的线性投影分别产生查询、键和值:
$ {\displaystyle Q = X_{\text{tgt}} W_Q, \quad K = X_{\text{src}} W_K, \quad V = X_{\text{src}} W_V} $
其中$ W_Q, W_K \in \mathbb{R}^{d \times d_k} $,$ W_V \in \mathbb{R}^{d \times d_v} $。缩放点积交叉注意力随后计算
$ {\displaystyle \operatorname{CrossAttn}(X_{\text{tgt}}, X_{\text{src}}) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \in \mathbb{R}^{n \times d_v}.} $
关键的结构性事实是两个操作数的非对称性:输出与目标长度相同,但其内容是从源中抽取的值向量的凸组合。注意力矩阵是形状为$ n \times m $的矩形矩阵,而不是像自注意力中那样是方阵。
在实践中,交叉注意力几乎总是采用其多头形式。设有$ h $个维度为$ d_k = d / h $的头,查询、键和值被分割,每个头分别应用注意力,然后将各头的输出拼接并进行线性投影:
$ {\displaystyle \operatorname{MultiHead}(X_{\text{tgt}}, X_{\text{src}}) = \operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)\, W_O.} $
不同的头可以各自专门化:有些按位置进行对齐,有些进行语义对齐,还有些则起到近似均匀平滑的作用。
在编码器-解码器 Transformer 中的应用
在原始Transformer中,每个解码器块包含三个子层:对部分已生成目标进行的掩码自注意力、对编码器栈最终输出进行的交叉注意力,以及按位置施加的前馈网络。交叉注意力子层是信息从源流向目标的唯一通道;将其移除后,解码器就退化为一个看不到源句子的普通语言模型。
由此可以得出几个实际的要点。由于编码器表示只计算一次并在每个解码步中复用,键和值可以跨步缓存,这使得交叉注意力比随部分目标增长的自注意力便宜得多。大多数生产环境中的解码器会为自注意力维护一个单独的 KV 缓存(按词元逐步增长),并为交叉注意力维护一个静态、预先计算好的 KV 张量。
交叉注意力也是源端填充掩码的应用位置:源序列中被填充的位置会被掩蔽,使得softmax对其分配零概率。相比之下,因果掩码在交叉注意力中并不必要——解码器在任何解码步都可以关注任何源位置。
变体
若干变体在基本交叉注意力层之上进行扩展或修改,以应对特定的约束。
带门控的交叉注意力在交叉注意力的输出上引入一个可学习的、通常初始化为零的门,使得新增的交叉注意力层不会破坏预训练模型的稳定性。这是 Flamingo 用来将视觉上下文嫁接到一个冻结的语言模型上的机制[4],更普遍地,它也是参数高效多模态适配的一种常见模式。
Perceiver 风格的交叉注意力使用一小组可学习的潜在向量作为查询,针对一段非常长的输入序列进行查询,将输入压缩为一个与其长度无关的固定大小表示。这打破了标准自注意力对输入长度的二次依赖,也正是 Perceiver 系列能够在不依赖特定模态分词器的前提下处理原始像素、音频采样和点云的原因。
扩散模型中的交叉注意力通过将网络的空间特征图作为查询、将条件嵌入作为键和值,来基于文本或类别嵌入对去噪网络进行条件化。同一机制在每一层、每一步去噪中反复应用,正是这一点赋予潜在扩散模型对生成图像的细粒度可控性。
记忆与检索式交叉注意力将源序列推广为一个被检索到的片段数据库。诸如 RETRO 以及 kNN 增强Transformer之类的架构会检索最近邻段落并对其进行交叉注意力,从而把模型的参数容量与其在推理时可访问的知识解耦开来。
交叉注意力与自注意力的比较
交叉注意力与自注意力之间的差异是结构性的而非算法性的:计算的是相同的缩放点积,但键和值来自不同的源。由此带来了若干实际后果。
注意力矩阵是矩形的,通常不是方阵,因此其代价为$ O(nm) $而非$ O(n^2) $;对于关注长源的短目标,这比在两者拼接序列上做自注意力要便宜得多。填充掩码仅作用于源端;在使用因果掩码时,它作用于目标自身的自注意力,而非交叉注意力。由于源表示在解码过程中是固定的,其键和值可以预先计算一次并复用,这在推理时是一项可观的收益。
一个更细微的要点是:如果源表示已经包含了来自先前编码器的位置信息,那么交叉注意力在源端就不需要再添加位置编码。在多模态场景下,源模态和目标模态的位置结构差异很大(例如,源是二维图像块,目标是一维文本),位置信息通常驻留在编码器内部,而不是在交叉注意力的边界处添加。
局限性
交叉注意力沿用了标准注意力在$ n \times m $矩形上的二次内存开销。当源非常长时——例如长文档、高分辨率图像或时长达数小时的音频——注意力矩阵会成为主导开销,因此需要各种稀疏、低秩或内存高效的近似方法[5]。
交叉注意力对源流和目标流之间的分布偏移也是出了名的脆弱。一个被训练为关注干净编码器输出的解码器,在编码器被替换或微调后可能会显著退化,因为键的几何结构可能以查询未曾预料的方式发生变化。联合训练、门控或精心设计的适配器通常可以缓解这一问题。
最后,交叉注意力本身并不能解决落地(grounding)或幻觉问题。该机制仅规定信息如何流动,并不强制目标忠实地反映源。使用交叉注意力训练的模型完全可能——并且确实会——忽略其条件,尤其是在自回归设定下,目标自身的自注意力可能压过交叉注意力信号。