Mixed Precision Training/zh

    From Marovi AI
    This page is a translated version of the page Mixed Precision Training and the translation is 100% complete.
    Other languages:
    Article
    Topic area Deep Learning
    Prerequisites Neural Networks, Backpropagation, Stochastic Gradient Descent


    概述

    混合精度训练是一种加速深度神经网络训练的技术,它将大多数算术运算以较低精度的浮点格式(通常为 16 位)执行,同时将少量数值敏感的运算保留在较高精度(通常为 32 位)。该技术的现代形式由 Micikevicius 等人于 2017 年提出,已成为大规模深度学习的默认训练方案,支撑着当代大多数关于卷积网络transformer和大型语言模型的工作。与纯单精度(FP32)训练相比,混合精度通常可将内存消耗减半,并在配备专用低精度矩阵单元的硬件上实现两到八倍的吞吐量,同时基本达到相同的最终准确度

    该方法利用了这样一个观察:神经网络训练对大多数张量——激活值、梯度和中间矩阵乘法——的数值噪声具有高度容忍性,但在少数关键位置需要高精度,特别是权重的主副本以及损失与优化器中的某些归约运算。

    浮点格式

    三种浮点格式主导着现代深度学习。历史上的基准是 FP32(IEEE 754 单精度),具有一个符号位、八个指数位和 23 个尾数位,动态范围约为 $ 10^{-38} $$ 10^{38} $,约七位十进制精度。

    FP16(IEEE 754 半精度)使用一个符号位、五个指数位和 10 个尾数位。其动态范围要窄得多——约为 $ 6 \times 10^{-5} $$ 6.5 \times 10^4 $——这是混合精度训练的核心数值挑战。小梯度可能下溢为零,大激活值可能上溢为无穷大。

    BF16(bfloat16)由 Google 为 TPU 引入,现已得到大多数现代加速器的支持。它保留了 FP32 的八个指数位,但将尾数截断为七位。它具有与 FP32 相同的动态范围,但精度远低于 FP16,这使其作为 FP32 的直接替代品要容易得多,因为下溢和上溢都很少见。代价是每次单独运算的舍入更粗糙。

    更新的 FP8 格式系列(E4M3 和 E5M2,于 2022 年标准化)将相同的思路扩展到八位,主要用于超大型 transformer 训练中的前向和反向矩阵乘法。FP8 通常需要按张量的缩放因子,并与更高精度的主格式一起使用。

    混合精度方案

    Micikevicius 等人的经典方案有三个组成部分。

    FP32 主权重。优化器在 FP32 中维护一份模型参数的主副本。在每次前向传递之前,将此主副本下转为低精度格式(FP16BF16)以产生网络中使用的工作权重。在优化器步骤之后,更新的是 FP32 主权重,而不是低精度副本。这可以防止 Stochastic Gradient Descent 产生的小参数更新在加到一个大得多的权重值上时由于舍入而丢失。

    具体来说,若 $ w $ 是一个权重而 $ \Delta w $ 是其更新,则 FP16 在 $ w \approx 1 $ 附近的可表示间距约为 $ 2^{-10} \approx 10^{-3} $。小于此量级的更新——在训练后期极为常见——如果加法在 FP16 中执行将完全丢失。

    低精度前向和反向传播。激活值、权重张量和梯度都以 FP16 或 BF16 存储。矩阵乘法卷积在专用的张量核或矩阵核上执行,这些核接受低精度输入,在内部以 FP32 累加,然后写回低精度输出。这正是内存和吞吐量收益的来源。

    损失缩放。由于 FP16 的动态范围有限,小于约 $ 2^{-24} $ 的梯度值会下溢为零。解决方法是在反向传播之前将损失乘以一个大的缩放因子 $ S $

    $ {\displaystyle L_{\mathrm{scaled}} = S \cdot L} $

    根据链式法则,每个梯度都会被相同的因子 $ S $ 缩放,将小值提升到下溢区域之外。反向传播之后,在优化器步骤之前,梯度在 FP32 中被反向缩放(除以 $ S $)。对于 BF16,损失缩放通常是不必要的,因为该格式继承了 FP32 的指数范围。

    动态损失缩放

    $ S $ 选择单一的静态值需要事先知道梯度分布。现代框架转而使用动态损失缩放,在训练过程中调整 $ S $

    • 从一个较大的初始值开始(例如,$ S = 2^{16} $)。
    • 在每次反向传播之后,检查是否有任何梯度包含无穷大或 NaN。
    • 如果检测到上溢,跳过该迭代的优化器步骤,并将 $ S $ 减半。
    • 如果在固定数量的迭代(例如 2000)内没有检测到上溢,则将 $ S $ 加倍。

    该程序在不因上溢而丢失迭代的前提下,将缩放保持在数值上尽可能大的水平,并随着训练过程中梯度大小的变化而自适应调整。

    保留在 FP32 中的运算

    在原本采用混合精度的计算图中,仍有少数运算通常保留在 FP32 中。这些运算的数值行为对范围或重复求和敏感:

    • 注意力头和分类头中使用的 softmax 和 log-softmax,其中较大 logit 之间的微小差异很重要。
    • 交叉熵损失计算,它将 softmax 与一个小数的对数结合起来。
    • Batch Normalization 统计量——均值、方差以及运行估计——这些会在多个样本上累加。
    • 长轴上的归约,例如用于裁剪的梯度范数。
    • 优化器状态(例如 Adam 的一阶和二阶矩估计),它们在多步上累积。

    框架通过 autocast 区域或运算允许列表来体现这种区分:矩阵乘法卷积会自动下转,而列出的运算则保留在 FP32 中。

    实现

    PyTorch 通过 torch.cuda.amp(最初的 FP16 API)和 torch.amp(统一的 FP16 / BF16 API)提供混合精度,并配合 GradScaler 实现损失缩放。TensorFlow 通过 tf.keras.mixed_precision 策略表达同样的思路。JAX 则使用显式的 dtype 控制,并配合 Optax 等库实现损失缩放。

    NVIDIA 的 Apex 库是第一个被广泛使用的混合精度工具包,早于框架原生 API;作为动态损失缩放的来源,它在历史上仍然重要。相关的 TF32 格式(由 Ampere 系列 GPU 隐式用于 FP32 矩阵乘法)有时被归为混合精度,但技术上是一种独立的优化,它将 FP32 输入通过降精度乘法器处理。

    与纯低精度的比较

    在没有主权重或损失缩放的情况下进行纯 FP16 训练通常会发散或停滞,原因在于更新下溢和梯度下溢。在没有主权重的情况下进行纯 BF16 训练对中等规模模型通常可行,但在长时间训练中往往会损失最终准确度,特别是对于凸尾损失,因为七位尾数过于粗糙,无法精确累积 Adam 的小矩。混合精度通过将优化器状态和主权重保留在 FP32 中来恢复这种准确度,同时仍能从低精度计算路径中获取大部分吞吐量收益。

    局限性与失败模式

    混合精度并非毫无陷阱。最常见的失败包括:

    • 训练初期的持续 NaN,通常由超过 FP16 最大值约 $ 6.5 \times 10^4 $ 的初始激活值引起。补救方法是使用 BF16、谨慎初始化或逐层梯度裁剪
    • 静默的准确度损失:当本应保留在 FP32 中的运算(例如对超长序列的 softmax)被意外地以 FP16 执行时发生。标准修复方法是审计 autocast 策略。
    • 损失缩放崩溃:动态缩放因子降至一并停留在那里。这表明存在真正的数值问题而非调参问题,通常指向不良数据或不稳定的模型组件。
    • 降低的可复现性:不同代际的张量核心可能为相同的 FP16 矩阵乘法产生略有不同的位级结果,这使得精确可复现性测试变得复杂。

    对于超大型模型,FP8 引入了额外的考虑——必须跟踪和更新逐张量的缩放因子——但高层结构与最初的 FP16 方案相同。

    参见

    参考文献

    [1] [2] [3] [4]

    1. Template:Cite arxiv
    2. Template:Cite arxiv
    3. Template:Cite arxiv
    4. NVIDIA, "Train With Mixed Precision," NVIDIA Deep Learning Performance Documentation, 2023.