Knowledge Distillation/zh

    From Marovi AI
    This page is a translated version of the page Knowledge Distillation and the translation is 100% complete.
    Other languages:
    Article
    Topic area Deep Learning
    Prerequisites Cross-Entropy Loss, Softmax Function, KL Divergence


    概述

    知識蒸餾是一種模型壓縮與知識遷移技術,其中訓練一個小型"學生"網絡來模仿一個更大、更準確的"教師"模型的行為,而不是直接從原始標籤中學習。學生的訓練目標將標準的監督損失與一個使學生輸出分佈趨向於教師在相同輸入上的輸出分佈的項相結合或予以替代。由於教師的輸出所編碼的信息比one-hot標籤更豐富,包括教師的置信度以及它認為哪些替代類別是合理的,學生通常能夠達到僅使用標籤從頭訓練時無法獲得的精度,而所需的計算與內存成本僅為一小部分。

    該技術由 Hinton、Vinyals 和 Dean 於 2015 年在現代深度學習中推廣,他們將其闡釋為傳遞嵌入在教師軟化 logits 中的"暗知識"。此後,它已成為生產深度學習流水線中的標準工具,部署於任何需要將強大但昂貴的模型替換為更便宜的推理模型的場景:從大型卷積集成中蒸餾出的移動視覺模型、從前沿基於 Transformer 的教師中蒸餾出的小型語言模型,以及從伺服器級系統中蒸餾出的設備端語音識別器。除壓縮之外,蒸餾還在訓練流水線中用於自我改進、集成壓縮、跨架構遷移,以及即使在教師與學生規模相同的情況下作為正則化器

    公式化

    標準公式化考慮一個具有 $ K $ 個類別的分類任務。令 $ z^t = f^t(x) $$ z^s = f^s(x) $ 表示教師和學生在輸入 $ x $ 上產生的 logits。Hinton 的關鍵手段是按溫度縮放的 Softmax 函數

    $ {\displaystyle p_i^{\tau}(z) = \frac{\exp(z_i / \tau)}{\sum_{j=1}^{K} \exp(z_j / \tau)}.} $

    溫度 $ \tau > 1 $ 會平滑分佈,提高非頂部類別的相對概率,並暴露教師對它們的相對信念。蒸餾損失使學生的軟化分佈與教師的相匹配:

    $ {\displaystyle \mathcal{L}_{\text{KD}}(x) = \tau^2 \, D_{\mathrm{KL}}\!\left(p^{\tau}(z^t) \,\|\, p^{\tau}(z^s)\right),} $

    其中因子 $ \tau^2 $ 補償了將 logits 除以 $ \tau $ 所引入的梯度縮放,使得蒸餾梯度的幅度在不同溫度下保持可比。總目標通常是與標準硬標籤 交叉熵損失的凸組合:

    $ {\displaystyle \mathcal{L}(x, y) = (1 - \alpha) \, \mathcal{L}_{\text{CE}}(y, p^{1}(z^s)) + \alpha \, \mathcal{L}_{\text{KD}}(x),} $

    其中 $ y $ 是真實標籤,$ \alpha \in [0, 1] $ 在兩項之間進行權衡,交叉在溫度 $ 1 $ 下評估,以使硬標籤監督不被軟化。典型的超參數為 $ \tau \in [2, 10] $$ \alpha \in [0.5, 0.9] $,其值在驗證集上調優。

    在高溫極限下,對軟化的 softmax 進行展開表明,最小化 KL 散度退化為在每個樣本均值意義下匹配教師的 logits,這導致了 Bucila、Caruana 和 Niculescu-Mizil 提出的較早的 logit 匹配變體。在溫度 $ 1 $ 下,蒸餾項退化為對教師預測分佈的普通交叉熵,恢復為"軟標籤訓練"。

    為什麼有效:暗知識

    Hinton 強調的直覺是:自信的教師對錯誤類別賦予的近零概率仍然攜帶信息。一個在 ImageNet 上訓練的模型在真實標籤為"垃圾車"時,可能為"寶馬"分配 $ 10^{-6} $ 的概率,為"胡蘿蔔"分配 $ 10^{-9} $ 的概率,而這些極小概率之間的比例編碼了寶馬比胡蘿蔔更像卡車。One-hot 標籤會破壞這種相似性結構;教師的軟化分佈則保留了它。因此,訓練學生重現完整分佈傳達了關於標籤空間幾何的歸納偏置,這是任何帶標籤示例本身都無法提供的。

    一種互補的觀點是,教師充當貝葉斯最優類別後驗的平滑估計器。當標籤是隨機的或模糊的時,教師的分佈會對可能的答案進行平均,為學生提供比標籤本身噪聲更小的訓練信號。從這個角度看,蒸餾是與 標籤平滑密切相關的一種 正則化形式:兩者都用更軟的目標替代了 one-hot 目標,但蒸餾的目標依賴於輸入而非均勻分佈。蒸餾的有效正則化強度已在相關工作中得到形式化,這些工作表明,當教師是校準良好的估計器時,它近似等價於對 偏差-方差權衡的特定調整。

    變體

    Hinton 風格的軟目標損失現在通常稱為響應蒸餾或 logit 蒸餾,因為監督位於網絡輸出。第二類,特徵蒸餾,則匹配中間表示:要求學生重現教師的隱藏激活或注意力,可能通過一個學習到的投影來實現。FitNets、注意力遷移以及更近期的特徵模仿損失均屬於此類。當僅靠輸出監督不足時,特徵蒸餾可以從教師中提取更多指導,特別是當架構差異較大、對齊輸出過於粗糙時。

    第三類,關係蒸餾,傳遞的是關於教師如何組織一批示例的結構信息,而非其絕對預測。Relational KD 和 Similarity-Preserving KD 等方法匹配激活的 Gram 矩陣或嵌入之間的成對距離,這使得監督對兩個網絡的精確特徵維度保持不變。

    蒸餾還根據教師與學生的訓練時機進行細分。離線蒸餾使用一個固定的預訓練教師;這是迄今為止最常見的設置。在線蒸餾共同訓練一組學生,其中每個學生將其他學生的聚合視為軟教師,從而無需單獨訓練的教師。自蒸餾在單一架構上迭代,一輪的學生成為下一輪的教師,並且令人驚訝的是,即使架構保持不變,通常也能提升精度。Born-again 網絡形式化了這種迭代的自蒸餾過程。

    對於語言模型而言,Kim 和 Rush 的序列級蒸餾將該技術適配於自回歸生成:訓練學生模仿教師的束搜索輸出而不是其每 token 分佈,從而避免了曝光偏差不匹配,並被廣泛用於壓縮翻譯與摘要模型。對於非常大的模型,蒸餾支撐了許多實用的小型 LM 配方,包括 DistilBERT、MobileBERT,以及從前沿教師生產推理便宜變體的更廣泛實踐。

    訓練與推理

    標準的離線蒸餾流水線運行如下。教師被訓練或下載並保持凍結。訓練循環遍歷帶標籤的訓練集,並對每個批次同時運行教師(處於評估模式)和學生。如果存儲允許,則預先計算教師的軟化概率;否則即時計算;存儲教師 logits 避免了跨周期的冗餘教師前向傳播,但對於 $ N $ 個訓練示例需要消耗 $ O(N K) $ 的額外內存。學生通過將組合損失反向傳播到自身參數中進行更新;教師永遠不會被更新。

    蒸餾可以在教師所見的相同數據上運行,可以在額外的無標籤數據上運行(因為軟標籤不需要真實標籤),也可以在保留的遷移集上運行。無標籤數據設置在生產中尤其有吸引力:可以通過依賴教師提供目標來將學生的訓練集擴展到遠超帶標籤語料庫,這本質上就是從前沿教師生產現代小型語言模型的方式。

    在推理時,教師被完全丟棄。學生作為獨立模型運行,沒有由蒸餾過程引入的任何架構開銷。

    比較

    蒸餾是三種主要模型壓縮策略之一,與 量化剪枝並列。量化降低固定架構的數值精度;剪枝從固定架構中移除權重或結構;蒸餾完全改變架構,通常用更淺或更窄的網絡替換深而寬的網絡。這三種方法在很大程度上是互補的,並且經常被組合使用:將前沿教師蒸餾到較小的架構中,然後將其剪枝並量化以進行部署。當原始教師遠大於部署預算所允許的規模時,蒸餾單獨往往在固定大小下提供最大的精度增益;而當架構已接近正確大小時,量化和剪枝則提供更好的增益。

    蒸餾在壓縮之外也有密切的概念近親。共蒸餾和在線蒸餾是多個學生互相教學的 集成方法訓練形式;半監督學習中的 mean-teacher 方法是自蒸餾的滑動平均形式;強化學習中的策略蒸餾使用應用於動作分佈而非類別概率的相同機制,將複雜策略轉移到更簡單的策略。

    局限性

    蒸餾並非免費。它需要一個工作的教師,而教師本身必須以某種代價被訓練,並且學生的精度的上界是所選學生架構原則上可以表示的範圍:缺乏建模任務容量的網絡不會被更軟的目標拯救。溫度和損失權重的選擇是經驗性的,病態組合(例如非常高的溫度配以非常低的 $ \alpha $)可能產生一個比起成功更忠實地模仿教師錯誤的學生。當教師校準不良時,軟標籤可能會主動損害學生;從已記住其訓練集的教師中蒸餾會將這種記憶傳播到學生中。

    特徵級蒸餾引入了額外的脆弱性:教師與學生特徵之間的投影對齊本身就是一個超參數,過於激進的特徵匹配可能使學生過度受限於教師的表示特性。對於生成模型,蒸餾以 Hinton 的響應級損失無法解決的方式與自回歸訓練的 曝光偏差相互作用,這促成了序列級變體的提出。最後,蒸餾不對遷移集所覆蓋的分佈之外的行為提供保證;蒸餾出的學生可能在教師從未被查詢過的輸入空間區域悄然失敗,這對於安全關鍵部署以及對 大型語言模型的蒸餾(其教師在訓練時被查詢過廣闊的輸入空間但在蒸餾時只接觸到狹窄的一部分)尤其令人擔憂。

    參考文獻

    [1] [2] [3] [4] [5] [6] [7] [8] [9]