Knowledge Distillation/zh
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Cross-Entropy Loss, Softmax Function, KL Divergence |
概述
知识蒸馏是一种模型压缩与知识迁移技术,其中训练一个小型"学生"网络来模仿一个更大、更准确的"教师"模型的行为,而不是直接从原始标签中学习。学生的训练目标将标准的监督损失与一个使学生输出分布趋向于教师在相同输入上的输出分布的项相结合或予以替代。由于教师的输出所编码的信息比one-hot标签更丰富,包括教师的置信度以及它认为哪些替代类别是合理的,学生通常能够达到仅使用标签从头训练时无法获得的精度,而所需的计算与内存成本仅为一小部分。
该技术由 Hinton、Vinyals 和 Dean 于 2015 年在现代深度学习中推广,他们将其阐释为传递嵌入在教师软化 logits 中的"暗知识"。此后,它已成为生产深度学习流水线中的标准工具,部署于任何需要将强大但昂贵的模型替换为更便宜的推理模型的场景:从大型卷积集成中蒸馏出的移动视觉模型、从前沿基于 Transformer 的教师中蒸馏出的小型语言模型,以及从服务器级系统中蒸馏出的设备端语音识别器。除压缩之外,蒸馏还在训练流水线中用于自我改进、集成压缩、跨架构迁移,以及即使在教师与学生规模相同的情况下作为正则化器。
公式化
标准公式化考虑一个具有 $ K $ 个类别的分类任务。令 $ z^t = f^t(x) $ 和 $ z^s = f^s(x) $ 表示教师和学生在输入 $ x $ 上产生的 logits。Hinton 的关键手段是按温度缩放的 Softmax 函数:
$ {\displaystyle p_i^{\tau}(z) = \frac{\exp(z_i / \tau)}{\sum_{j=1}^{K} \exp(z_j / \tau)}.} $
温度 $ \tau > 1 $ 会平滑分布,提高非顶部类别的相对概率,并暴露教师对它们的相对信念。蒸馏损失使学生的软化分布与教师的相匹配:
$ {\displaystyle \mathcal{L}_{\text{KD}}(x) = \tau^2 \, D_{\mathrm{KL}}\!\left(p^{\tau}(z^t) \,\|\, p^{\tau}(z^s)\right),} $
其中因子 $ \tau^2 $ 补偿了将 logits 除以 $ \tau $ 所引入的梯度缩放,使得蒸馏梯度的幅度在不同温度下保持可比。总目标通常是与标准硬标签 交叉熵损失的凸组合:
$ {\displaystyle \mathcal{L}(x, y) = (1 - \alpha) \, \mathcal{L}_{\text{CE}}(y, p^{1}(z^s)) + \alpha \, \mathcal{L}_{\text{KD}}(x),} $
其中 $ y $ 是真实标签,$ \alpha \in [0, 1] $ 在两项之间进行权衡,交叉熵在温度 $ 1 $ 下评估,以使硬标签监督不被软化。典型的超参数为 $ \tau \in [2, 10] $ 和 $ \alpha \in [0.5, 0.9] $,其值在验证集上调优。
在高温极限下,对软化的 softmax 进行展开表明,最小化 KL 散度退化为在每个样本均值意义下匹配教师的 logits,这导致了 Bucila、Caruana 和 Niculescu-Mizil 提出的较早的 logit 匹配变体。在温度 $ 1 $ 下,蒸馏项退化为对教师预测分布的普通交叉熵,恢复为"软标签训练"。
为什么有效:暗知识
Hinton 强调的直觉是:自信的教师对错误类别赋予的近零概率仍然携带信息。一个在 ImageNet 上训练的模型在真实标签为"垃圾车"时,可能为"宝马"分配 $ 10^{-6} $ 的概率,为"胡萝卜"分配 $ 10^{-9} $ 的概率,而这些极小概率之间的比例编码了宝马比胡萝卜更像卡车。One-hot 标签会破坏这种相似性结构;教师的软化分布则保留了它。因此,训练学生重现完整分布传达了关于标签空间几何的归纳偏置,这是任何带标签示例本身都无法提供的。
一种互补的观点是,教师充当贝叶斯最优类别后验的平滑估计器。当标签是随机的或模糊的时,教师的分布会对可能的答案进行平均,为学生提供比标签本身噪声更小的训练信号。从这个角度看,蒸馏是与 标签平滑密切相关的一种 正则化形式:两者都用更软的目标替代了 one-hot 目标,但蒸馏的目标依赖于输入而非均匀分布。蒸馏的有效正则化强度已在相关工作中得到形式化,这些工作表明,当教师是校准良好的估计器时,它近似等价于对 偏差-方差权衡的特定调整。
变体
Hinton 风格的软目标损失现在通常称为响应蒸馏或 logit 蒸馏,因为监督位于网络输出。第二类,特征蒸馏,则匹配中间表示:要求学生重现教师的隐藏激活或注意力图,可能通过一个学习到的投影来实现。FitNets、注意力迁移以及更近期的特征模仿损失均属于此类。当仅靠输出监督不足时,特征蒸馏可以从教师中提取更多指导,特别是当架构差异较大、对齐输出过于粗糙时。
第三类,关系蒸馏,传递的是关于教师如何组织一批示例的结构信息,而非其绝对预测。Relational KD 和 Similarity-Preserving KD 等方法匹配激活的 Gram 矩阵或嵌入之间的成对距离,这使得监督对两个网络的精确特征维度保持不变。
蒸馏还根据教师与学生的训练时机进行细分。离线蒸馏使用一个固定的预训练教师;这是迄今为止最常见的设置。在线蒸馏共同训练一组学生,其中每个学生将其他学生的聚合视为软教师,从而无需单独训练的教师。自蒸馏在单一架构上迭代,一轮的学生成为下一轮的教师,并且令人惊讶的是,即使架构保持不变,通常也能提升精度。Born-again 网络形式化了这种迭代的自蒸馏过程。
对于语言模型而言,Kim 和 Rush 的序列级蒸馏将该技术适配于自回归生成:训练学生模仿教师的束搜索输出而不是其每 token 分布,从而避免了曝光偏差不匹配,并被广泛用于压缩翻译与摘要模型。对于非常大的模型,蒸馏支撑了许多实用的小型 LM 配方,包括 DistilBERT、MobileBERT,以及从前沿教师生产推理便宜变体的更广泛实践。
训练与推理
标准的离线蒸馏流水线运行如下。教师被训练或下载并保持冻结。训练循环遍历带标签的训练集,并对每个批次同时运行教师(处于评估模式)和学生。如果存储允许,则预先计算教师的软化概率;否则即时计算;存储教师 logits 避免了跨周期的冗余教师前向传播,但对于 $ N $ 个训练示例需要消耗 $ O(N K) $ 的额外内存。学生通过将组合损失反向传播到自身参数中进行更新;教师永远不会被更新。
蒸馏可以在教师所见的相同数据上运行,可以在额外的无标签数据上运行(因为软标签不需要真实标签),也可以在保留的迁移集上运行。无标签数据设置在生产中尤其有吸引力:可以通过依赖教师提供目标来将学生的训练集扩展到远超带标签语料库,这本质上就是从前沿教师生产现代小型语言模型的方式。
在推理时,教师被完全丢弃。学生作为独立模型运行,没有由蒸馏过程引入的任何架构开销。
比较
蒸馏是三种主要模型压缩策略之一,与 量化和 剪枝并列。量化降低固定架构的数值精度;剪枝从固定架构中移除权重或结构;蒸馏完全改变架构,通常用更浅或更窄的网络替换深而宽的网络。这三种方法在很大程度上是互补的,并且经常被组合使用:将前沿教师蒸馏到较小的架构中,然后将其剪枝并量化以进行部署。当原始教师远大于部署预算所允许的规模时,蒸馏单独往往在固定大小下提供最大的精度增益;而当架构已接近正确大小时,量化和剪枝则提供更好的增益。
蒸馏在压缩之外也有密切的概念近亲。共蒸馏和在线蒸馏是多个学生互相教学的 集成方法训练形式;半监督学习中的 mean-teacher 方法是自蒸馏的滑动平均形式;强化学习中的策略蒸馏使用应用于动作分布而非类别概率的相同机制,将复杂策略转移到更简单的策略。
局限性
蒸馏并非免费。它需要一个工作的教师,而教师本身必须以某种代价被训练,并且学生的精度的上界是所选学生架构原则上可以表示的范围:缺乏建模任务容量的网络不会被更软的目标拯救。温度和损失权重的选择是经验性的,病态组合(例如非常高的温度配以非常低的 $ \alpha $)可能产生一个比起成功更忠实地模仿教师错误的学生。当教师校准不良时,软标签可能会主动损害学生;从已记住其训练集的教师中蒸馏会将这种记忆传播到学生中。
特征级蒸馏引入了额外的脆弱性:教师与学生特征之间的投影对齐本身就是一个超参数,过于激进的特征匹配可能使学生过度受限于教师的表示特性。对于生成模型,蒸馏以 Hinton 的响应级损失无法解决的方式与自回归训练的 曝光偏差相互作用,这促成了序列级变体的提出。最后,蒸馏不对迁移集所覆盖的分布之外的行为提供保证;蒸馏出的学生可能在教师从未被查询过的输入空间区域悄然失败,这对于安全关键部署以及对 大型语言模型的蒸馏(其教师在训练时被查询过广阔的输入空间但在蒸馏时只接触到狭窄的一部分)尤其令人担忧。
参考文献
[1] [2] [3] [4] [5] [6] [7] [8] [9]
- ↑ Template:Cite arxiv
- ↑ Bucila, C., Caruana, R., and Niculescu-Mizil, A. Model Compression. KDD, 2006.
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv