RMSNorm/zh

    From Marovi AI
    This page is a translated version of the page RMSNorm and the translation is 77% complete.
    Other languages:
    Article
    Topic area neural-network-components
    Prerequisites Layer Normalization, Transformer


    概述

    均方根层归一化RMSNorm)是一种用于深度神经网络的归一化层,由 Zhang 和 Sennrich 于 2019 年提出,作为 Layer Normalization 的简化版本。RMSNorm 通过输入vector的均方根(RMS)对其进行重新缩放,并对每个特征应用一个可学习的增益,但与 LayerNorm 不同的是,它省略了均值中心化(重新中心化)步骤以及加性bias。其结果是一个算术开销约为 LayerNorm 一半的层,同时保留了 LayerNorm 的大部分训练稳定性优势。

    RMSNorm 已成为现代 Transformer 架构(尤其是大型 Language Model)的标准组件。它是 T5、LLaMA、LLaMA 2、LLaMA 3、Mistral、Gemma、Qwen 以及许多其他开放权重模型所使用的归一化层。其流行stems于若干因素的结合:实现简单、计算速度略快、与 Pre-LayerNorm 布局配合时训练稳定性略有提升,以及一项经验观察——对activations进行重新中心化几乎不会改变下游任务的质量。

    动机

    Layer Normalization 通过减去均值并除以标准差来标准化一个vector $ x \in \mathbb{R}^d $,然后应用一个可学习的缩放因子 $ \gamma $ 和偏移 $ \beta $。重新中心化步骤(减去均值)最初的动机来自与 Batch Normalization 的类比,后者沿批次维度对每个特征进行中心化。Zhang 和 Sennrich 观察到,LayerNorm 在深层网络中的实际收益几乎全部来自重新缩放而非重新中心化:该层的作用是将activation的幅值保持在有界范围内,从而使梯度既不消失也不爆炸。均值减法步骤需要额外遍历向量一次,引入一个额外参数($ \beta $),却对优化几乎没有贡献。

    RMSNorm 回答的正是这样一个问题:"如果我们只保留重新缩放部分会怎样?" 它对输入的逐向量重新缩放保持不变,但与 LayerNorm 不同,它对常数偏移并不保持不变。经验表明,在现代 Transformer 的训练中,这一缺失的不变性并未带来任何可测量的代价。

    公式

    给定一个输入vector $ x \in \mathbb{R}^d $RMSNorm 计算

    $ {\displaystyle \operatorname{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2}} $

    并输出

    $ {\displaystyle \operatorname{RMSNorm}(x)_i = \frac{x_i}{\operatorname{RMS}(x) + \varepsilon}\, g_i} $

    其中 $ g \in \mathbb{R}^d $ 是可学习的增益向量(初始化为全 1),$ \varepsilon $ 是为数值稳定性而加入的小常数(通常为 $ 10^{-6} $$ 10^{-5} $)。该层不含可学习的bias,也不进行均值减法。

    对于一个 token 向量序列,RMSNorm 按 token 独立应用,方式与 Transformers 中的 LayerNorm 完全相同。以matrix形式表示,若 $ X \in \mathbb{R}^{n \times d} $$ n $ 个 token 向量按行堆叠,则

    $ {\displaystyle \operatorname{RMSNorm}(X) = \operatorname{diag}(g)\, X \oslash \sqrt{\frac{1}{d}(X \odot X)\mathbf{1} + \varepsilon}} $

    其中 $ \odot $$ \oslash $ 分别表示逐元素乘法和除法,$ \mathbf{1} $ 是全 1 向量。

    与层归一化的关系

    LayerNorm 计算均值 $ \mu = \frac{1}{d}\sum_i x_i $ 和标准差 $ \sigma = \sqrt{\frac{1}{d}\sum_i (x_i - \mu)^2} $,然后输出 $ \gamma \odot (x - \mu)/(\sigma + \varepsilon) + \beta $RMSNorm 恰好等价于将 $ \mu $ 强制设为零并移除 $ \beta $ 后的 LayerNorm。等价地说,RMSNorm 是限制于输入分布已经被中心化情形下的 LayerNorm,而在深层 Transformer 中,经过若干层训练后这一条件会近似成立。

    LayerNorm 在 $ d $vector上的算术开销大约是两次归约(均值和variance)加一次逐元素仿射变换;而 RMSNorm 只有一次归约(平方和)加一次逐元素缩放。在现代 GPU 上,单独看挂钟时间的节省很小,但在数百层、数万亿 token 的训练中累积起来就不可忽略。更重要的是,更简单的kernel使得编写融合实现以及在混合precision下保持数值稳定都更加容易。

    在 Transformer 中的位置

    RMSNorm 几乎总是用于 Pre-LayerNorm(前置归一化)残差结构:归一化层应用于每个子块的输入,而residual connection将未归一化的输入与子块输出相加。以 Self-attention 子块为例,示意如下:

    $ {\displaystyle y = x + \operatorname{Attention}(\operatorname{RMSNorm}(x))} $

    Feedforward Network 子块的处理方式与之类似。这与原始 TransformerPost-LayerNorm(后置归一化)布局形成对比,后者在残差相加之后才进行归一化。使用 RMSNorm 的前置归一化能够在无需学习率warmup等技巧的情况下稳定训练,并且能扩展到后置归一化通常会发散的规模,是当前几乎所有现代开放权重 LLM 的默认选择。

    通常还会在最后一个块的输出上、在反嵌入投影之前再应用一次最终的 RMSNorm。

    变体与扩展

    RMSNorm 有若干变体出现在文献和实际生产模型中。

    Partial RMSNorm 仅使用输入的前 $ k < d $ 个坐标来计算 RMS。其直觉是:高维向量的 RMS 高度集中在其均值附近,因此部分求和几乎同样精确。Zhang 和 Sennrich 报告称,在 $ k = d/8 $ 时质量损失可以忽略不计,且训练更快。Partial RMSNorm 在实践中并未被广泛采用,因为其绝对加速量很小,而现代注意力内核才是运行时的主导部分。

    Gated RMSNorm 将 RMSNorm 的输出与输入或另一个tensor的门控函数相乘。它被应用于一些 State Space Model 架构中,尤其是 Mamba 2,在其中用于在投影之前对 SSM 的输出进行门控。

    Group RMSNorm 将特征维度划分为 $ G $ 个组,每组使用各自的增益向量独立进行归一化,这与适用于convolutional networksGroup Normalization 类似。Group RMSNorm 被用于若干近期架构中的分组注意力头内部。

    QK NormalizationSelf-attention 内部,于dot product之前对查询向量和键向量应用 RMSNorm。这通过防止 softmax 之前的 logits 漂移到极端幅度,从而在超大规模训练中起到稳定作用,并被用于 Gemma 2、若干前沿模型以及 Vision Transformer ViT-22B 中。

    实现考量

    一种 PyTorch 的参考实现是深度学习中最简单的层之一:

    <syntaxhighlight lang="python"> class RMSNorm(nn.Module):

       def __init__(self, dim, eps=1e-6):
           super().__init__()
           self.weight = nn.Parameter(torch.ones(dim))
           self.eps = eps
    
       def forward(self, x):
           # x: (..., dim)
           rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
           return self.weight * (x * rms)
    

    </syntaxhighlight>

    生产级实现会将整个层融合到一个 CUDA kernel中以最小化内存流量,并在 fp32 下计算平方和——即使输入是 bf16fp16——以避免在 $ d $ 很大时(例如 LLaMA 70B 中的 8192)发生溢出。增益向量通常以与activations相同的precision存储,仅在融合kernel内部才被转换为 fp32。$ \varepsilon $ 加在平方根之内而非之外,以便在输入接近零时仍保持梯度行为良好。

    Apex, FlashAttention, and the major training frameworks all ship optimized RMSNorm kernels. The forward and backward passes are bandwidth-bound rather than compute-bound on accelerators.

    Empirical Performance

    Across machine translation, language modeling, and downstream evaluation tasks, RMSNorm matches or slightly exceeds LayerNorm in final task quality while training 5-10% faster end-to-end on common Transformer sizes. The original paper reported wins of 0.1-0.3 BLEU on WMT translation benchmarks at no quality cost, with ablations on RNN-based models showing similar trends. Subsequent large-scale studies, including the empirical work behind the T5 and LLaMA model families, found no scenario in which LayerNorm provided a clear quality advantage. Combined with its simpler kernel and good interaction with pre-norm residual connections, this evidence has made RMSNorm the default normalization choice in essentially every Transformer-based Language Model released since 2022.

    Limitations

    RMSNorm is not invariant to a constant additive shift of its input, so models that rely on encoding information in the mean of a representation could in principle behave differently under RMSNorm than under LayerNorm. In practice, Self-attention and Feedforward Network sub-blocks do not exploit the mean of their input, and the activations in pre-norm Transformers are approximately mean-zero by the second or third block, so the missing invariance is not measurable in real models. RMSNorm shares with LayerNorm the property that all coordinates of a vector influence the normalization of every other coordinate, which couples activations across the feature dimension and complicates certain forms of model surgery (such as pruning or per-feature quantization) more than would a per-feature normalizer like Batch Normalization.

    See Also

    References