Mixed Precision Training/zh
| Article | |
|---|---|
| Topic area | Deep Learning |
| Prerequisites | Neural Networks, Backpropagation, Stochastic Gradient Descent |
概述
混合精度訓練是一種加速深度神經網絡訓練的技術,它將大多數算術運算以較低精度的浮點格式(通常為 16 位)執行,同時將少量數值敏感的運算保留在較高精度(通常為 32 位)。該技術的現代形式由 Micikevicius 等人於 2017 年提出,已成為大規模深度學習的默認訓練方案,支撐着當代大多數關於卷積網絡、transformer和大型語言模型的工作。與純單精度(FP32)訓練相比,混合精度通常可將內存消耗減半,並在配備專用低精度矩陣單元的硬件上實現兩到八倍的吞吐量,同時基本達到相同的最終準確度。
該方法利用了這樣一個觀察:神經網絡訓練對大多數張量——激活值、梯度和中間矩陣乘法——的數值噪聲具有高度容忍性,但在少數關鍵位置需要高精度,特別是權重的主副本以及損失與優化器中的某些歸約運算。
浮點格式
三種浮點格式主導着現代深度學習。歷史上的基準是 FP32(IEEE 754 單精度),具有一個符號位、八個指數位和 23 個尾數位,動態範圍約為 $ 10^{-38} $ 到 $ 10^{38} $,約七位十進制精度。
FP16(IEEE 754 半精度)使用一個符號位、五個指數位和 10 個尾數位。其動態範圍要窄得多——約為 $ 6 \times 10^{-5} $ 到 $ 6.5 \times 10^4 $——這是混合精度訓練的核心數值挑戰。小梯度可能下溢為零,大激活值可能上溢為無窮大。
BF16(bfloat16)由 Google 為 TPU 引入,現已得到大多數現代加速器的支持。它保留了 FP32 的八個指數位,但將尾數截斷為七位。它具有與 FP32 相同的動態範圍,但精度遠低於 FP16,這使其作為 FP32 的直接替代品要容易得多,因為下溢和上溢都很少見。代價是每次單獨運算的捨入更粗糙。
更新的 FP8 格式系列(E4M3 和 E5M2,於 2022 年標準化)將相同的思路擴展到八位,主要用於超大型 transformer 訓練中的前向和反向矩陣乘法。FP8 通常需要按張量的縮放因子,並與更高精度的主格式一起使用。
混合精度方案
Micikevicius 等人的經典方案有三個組成部分。
FP32 主權重。優化器在 FP32 中維護一份模型參數的主副本。在每次前向傳遞之前,將此主副本下轉為低精度格式(FP16 或 BF16)以產生網絡中使用的工作權重。在優化器步驟之後,更新的是 FP32 主權重,而不是低精度副本。這可以防止 Stochastic Gradient Descent 產生的小參數更新在加到一個大得多的權重值上時由於捨入而丟失。
具體來說,若 $ w $ 是一個權重而 $ \Delta w $ 是其更新,則 FP16 在 $ w \approx 1 $ 附近的可表示間距約為 $ 2^{-10} \approx 10^{-3} $。小於此量級的更新——在訓練後期極為常見——如果加法在 FP16 中執行將完全丟失。
低精度前向和反向傳播。激活值、權重張量和梯度都以 FP16 或 BF16 存儲。矩陣乘法和卷積在專用的張量核或矩陣核上執行,這些核接受低精度輸入,在內部以 FP32 累加,然後寫回低精度輸出。這正是內存和吞吐量收益的來源。
損失縮放。由於 FP16 的動態範圍有限,小於約 $ 2^{-24} $ 的梯度值會下溢為零。解決方法是在反向傳播之前將損失乘以一個大的縮放因子 $ S $:
$ {\displaystyle L_{\mathrm{scaled}} = S \cdot L} $
根據鏈式法則,每個梯度都會被相同的因子 $ S $ 縮放,將小值提升到下溢區域之外。反向傳播之後,在優化器步驟之前,梯度在 FP32 中被反向縮放(除以 $ S $)。對於 BF16,損失縮放通常是不必要的,因為該格式繼承了 FP32 的指數範圍。
動態損失縮放
為 $ S $ 選擇單一的靜態值需要事先知道梯度分佈。現代框架轉而使用動態損失縮放,在訓練過程中調整 $ S $:
- 從一個較大的初始值開始(例如,$ S = 2^{16} $)。
- 在每次反向傳播之後,檢查是否有任何梯度包含無窮大或 NaN。
- 如果檢測到上溢,跳過該迭代的優化器步驟,並將 $ S $ 減半。
- 如果在固定數量的迭代(例如 2000)內沒有檢測到上溢,則將 $ S $ 加倍。
該程序在不因上溢而丟失迭代的前提下,將縮放保持在數值上儘可能大的水平,並隨着訓練過程中梯度大小的變化而自適應調整。
保留在 FP32 中的運算
在原本採用混合精度的計算圖中,仍有少數運算通常保留在 FP32 中。這些運算的數值行為對範圍或重複求和敏感:
- 注意力頭和分類頭中使用的 softmax 和 log-softmax,其中較大 logit 之間的微小差異很重要。
- 交叉熵損失計算,它將 softmax 與一個小數的對數結合起來。
- Batch Normalization 統計量——均值、方差以及運行估計——這些會在多個樣本上累加。
- 長軸上的歸約,例如用於裁剪的梯度範數。
- 優化器狀態(例如 Adam 的一階和二階矩估計),它們在多步上累積。
框架通過 autocast 區域或運算允許列表來體現這種區分:矩陣乘法和卷積會自動下轉,而列出的運算則保留在 FP32 中。
實現
PyTorch 通過 torch.cuda.amp(最初的 FP16 API)和 torch.amp(統一的 FP16 / BF16 API)提供混合精度,並配合 GradScaler 實現損失縮放。TensorFlow 通過 tf.keras.mixed_precision 策略表達同樣的思路。JAX 則使用顯式的 dtype 控制,並配合 Optax 等庫實現損失縮放。
NVIDIA 的 Apex 庫是第一個被廣泛使用的混合精度工具包,早於框架原生 API;作為動態損失縮放的來源,它在歷史上仍然重要。相關的 TF32 格式(由 Ampere 系列 GPU 隱式用於 FP32 矩陣乘法)有時被歸為混合精度,但技術上是一種獨立的優化,它將 FP32 輸入通過降精度乘法器處理。
與純低精度的比較
在沒有主權重或損失縮放的情況下進行純 FP16 訓練通常會發散或停滯,原因在於更新下溢和梯度下溢。在沒有主權重的情況下進行純 BF16 訓練對中等規模模型通常可行,但在長時間訓練中往往會損失最終準確度,特別是對於凸尾損失,因為七位尾數過於粗糙,無法精確累積 Adam 的小矩。混合精度通過將優化器狀態和主權重保留在 FP32 中來恢復這種準確度,同時仍能從低精度計算路徑中獲取大部分吞吐量收益。
局限性與失敗模式
混合精度並非毫無陷阱。最常見的失敗包括:
- 訓練初期的持續 NaN,通常由超過 FP16 最大值約 $ 6.5 \times 10^4 $ 的初始激活值引起。補救方法是使用 BF16、謹慎初始化或逐層梯度裁剪。
- 靜默的準確度損失:當本應保留在 FP32 中的運算(例如對超長序列的 softmax)被意外地以 FP16 執行時發生。標準修複方法是審計 autocast 策略。
- 損失縮放崩潰:動態縮放因子降至一併停留在那裏。這表明存在真正的數值問題而非調參問題,通常指向不良數據或不穩定的模型組件。
- 降低的可復現性:不同代際的張量核心可能為相同的 FP16 矩陣乘法產生略有不同的位級結果,這使得精確可復現性測試變得複雜。
對於超大型模型,FP8 引入了額外的考慮——必須跟蹤和更新逐張量的縮放因子——但高層結構與最初的 FP16 方案相同。
參見
參考文獻
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ Template:Cite arxiv
- ↑ NVIDIA, "Train With Mixed Precision," NVIDIA Deep Learning Performance Documentation, 2023.