RMSNorm/zh

    From Marovi AI
    This page is a translated version of the page RMSNorm and the translation is 77% complete.
    Other languages:
    Article
    Topic area neural-network-components
    Prerequisites Layer Normalization, Transformer


    概述

    均方根層歸一化RMSNorm)是一種用於深度神經網絡的歸一化層,由 Zhang 和 Sennrich 於 2019 年提出,作為 Layer Normalization 的簡化版本。RMSNorm 通過輸入vector的均方根(RMS)對其進行重新縮放,並對每個特徵應用一個可學習的增益,但與 LayerNorm 不同的是,它省略了均值中心化(重新中心化)步驟以及加性bias。其結果是一個算術開銷約為 LayerNorm 一半的層,同時保留了 LayerNorm 的大部分訓練穩定性優勢。

    RMSNorm 已成為現代 Transformer 架構(尤其是大型 Language Model)的標準組件。它是 T5、LLaMA、LLaMA 2、LLaMA 3、Mistral、Gemma、Qwen 以及許多其他開放權重模型所使用的歸一化層。其流行stems於若干因素的結合:實現簡單、計算速度略快、與 Pre-LayerNorm 佈局配合時訓練穩定性略有提升,以及一項經驗觀察——對activations進行重新中心化幾乎不會改變下游任務的質量。

    動機

    Layer Normalization 通過減去均值並除以標準差來標準化一個vector $ x \in \mathbb{R}^d $,然後應用一個可學習的縮放因子 $ \gamma $ 和偏移 $ \beta $。重新中心化步驟(減去均值)最初的動機來自與 Batch Normalization 的類比,後者沿批次維度對每個特徵進行中心化。Zhang 和 Sennrich 觀察到,LayerNorm 在深層網絡中的實際收益幾乎全部來自重新縮放而非重新中心化:該層的作用是將activation的幅值保持在有界範圍內,從而使梯度既不消失也不爆炸。均值減法步驟需要額外遍歷向量一次,引入一個額外參數($ \beta $),卻對優化幾乎沒有貢獻。

    RMSNorm 回答的正是這樣一個問題:"如果我們只保留重新縮放部分會怎樣?" 它對輸入的逐向量重新縮放保持不變,但與 LayerNorm 不同,它對常數偏移並不保持不變。經驗表明,在現代 Transformer 的訓練中,這一缺失的不變性並未帶來任何可測量的代價。

    公式

    給定一個輸入vector $ x \in \mathbb{R}^d $RMSNorm 計算

    $ {\displaystyle \operatorname{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2}} $

    並輸出

    $ {\displaystyle \operatorname{RMSNorm}(x)_i = \frac{x_i}{\operatorname{RMS}(x) + \varepsilon}\, g_i} $

    其中 $ g \in \mathbb{R}^d $ 是可學習的增益向量(初始化為全 1),$ \varepsilon $ 是為數值穩定性而加入的小常數(通常為 $ 10^{-6} $$ 10^{-5} $)。該層不含可學習的bias,也不進行均值減法。

    對於一個 token 向量序列,RMSNorm 按 token 獨立應用,方式與 Transformers 中的 LayerNorm 完全相同。以matrix形式表示,若 $ X \in \mathbb{R}^{n \times d} $$ n $ 個 token 向量按行堆疊,則

    $ {\displaystyle \operatorname{RMSNorm}(X) = \operatorname{diag}(g)\, X \oslash \sqrt{\frac{1}{d}(X \odot X)\mathbf{1} + \varepsilon}} $

    其中 $ \odot $$ \oslash $ 分別表示逐元素乘法和除法,$ \mathbf{1} $ 是全 1 向量。

    與層歸一化的關係

    LayerNorm 計算均值 $ \mu = \frac{1}{d}\sum_i x_i $ 和標準差 $ \sigma = \sqrt{\frac{1}{d}\sum_i (x_i - \mu)^2} $,然後輸出 $ \gamma \odot (x - \mu)/(\sigma + \varepsilon) + \beta $RMSNorm 恰好等價於將 $ \mu $ 強制設為零並移除 $ \beta $ 後的 LayerNorm。等價地說,RMSNorm 是限制於輸入分佈已經被中心化情形下的 LayerNorm,而在深層 Transformer 中,經過若干層訓練後這一條件會近似成立。

    LayerNorm 在 $ d $vector上的算術開銷大約是兩次歸約(均值和variance)加一次逐元素仿射變換;而 RMSNorm 只有一次歸約(平方和)加一次逐元素縮放。在現代 GPU 上,單獨看掛鍾時間的節省很小,但在數百層、數萬億 token 的訓練中累積起來就不可忽略。更重要的是,更簡單的kernel使得編寫融合實現以及在混合precision下保持數值穩定都更加容易。

    在 Transformer 中的位置

    RMSNorm 幾乎總是用於 Pre-LayerNorm(前置歸一化)殘差結構:歸一化層應用於每個子塊的輸入,而residual connection將未歸一化的輸入與子塊輸出相加。以 Self-attention 子塊為例,示意如下:

    $ {\displaystyle y = x + \operatorname{Attention}(\operatorname{RMSNorm}(x))} $

    Feedforward Network 子塊的處理方式與之類似。這與原始 TransformerPost-LayerNorm(後置歸一化)佈局形成對比,後者在殘差相加之後才進行歸一化。使用 RMSNorm 的前置歸一化能夠在無需學習率warmup等技巧的情況下穩定訓練,並且能擴展到後置歸一化通常會發散的規模,是當前幾乎所有現代開放權重 LLM 的默認選擇。

    通常還會在最後一個塊的輸出上、在反嵌入投影之前再應用一次最終的 RMSNorm。

    變體與擴展

    RMSNorm 有若干變體出現在文獻和實際生產模型中。

    Partial RMSNorm 僅使用輸入的前 $ k < d $ 個坐標來計算 RMS。其直覺是:高維向量的 RMS 高度集中在其均值附近,因此部分求和幾乎同樣精確。Zhang 和 Sennrich 報告稱,在 $ k = d/8 $ 時質量損失可以忽略不計,且訓練更快。Partial RMSNorm 在實踐中並未被廣泛採用,因為其絕對加速量很小,而現代注意力內核才是運行時的主導部分。

    Gated RMSNorm 將 RMSNorm 的輸出與輸入或另一個tensor的門控函數相乘。它被應用於一些 State Space Model 架構中,尤其是 Mamba 2,在其中用於在投影之前對 SSM 的輸出進行門控。

    Group RMSNorm 將特徵維度劃分為 $ G $ 個組,每組使用各自的增益向量獨立進行歸一化,這與適用於convolutional networksGroup Normalization 類似。Group RMSNorm 被用於若干近期架構中的分組注意力頭內部。

    QK NormalizationSelf-attention 內部,於dot product之前對查詢向量和鍵向量應用 RMSNorm。這通過防止 softmax 之前的 logits 漂移到極端幅度,從而在超大規模訓練中起到穩定作用,並被用於 Gemma 2、若干前沿模型以及 Vision Transformer ViT-22B 中。

    實現考量

    一種 PyTorch 的參考實現是深度學習中最簡單的層之一:

    <syntaxhighlight lang="python"> class RMSNorm(nn.Module):

       def __init__(self, dim, eps=1e-6):
           super().__init__()
           self.weight = nn.Parameter(torch.ones(dim))
           self.eps = eps
    
       def forward(self, x):
           # x: (..., dim)
           rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
           return self.weight * (x * rms)
    

    </syntaxhighlight>

    生產級實現會將整個層融合到一個 CUDA kernel中以最小化內存流量,並在 fp32 下計算平方和——即使輸入是 bf16fp16——以避免在 $ d $ 很大時(例如 LLaMA 70B 中的 8192)發生溢出。增益向量通常以與activations相同的precision存儲,僅在融合kernel內部才被轉換為 fp32。$ \varepsilon $ 加在平方根之內而非之外,以便在輸入接近零時仍保持梯度行為良好。

    Apex, FlashAttention, and the major training frameworks all ship optimized RMSNorm kernels. The forward and backward passes are bandwidth-bound rather than compute-bound on accelerators.

    Empirical Performance

    Across machine translation, language modeling, and downstream evaluation tasks, RMSNorm matches or slightly exceeds LayerNorm in final task quality while training 5-10% faster end-to-end on common Transformer sizes. The original paper reported wins of 0.1-0.3 BLEU on WMT translation benchmarks at no quality cost, with ablations on RNN-based models showing similar trends. Subsequent large-scale studies, including the empirical work behind the T5 and LLaMA model families, found no scenario in which LayerNorm provided a clear quality advantage. Combined with its simpler kernel and good interaction with pre-norm residual connections, this evidence has made RMSNorm the default normalization choice in essentially every Transformer-based Language Model released since 2022.

    Limitations

    RMSNorm is not invariant to a constant additive shift of its input, so models that rely on encoding information in the mean of a representation could in principle behave differently under RMSNorm than under LayerNorm. In practice, Self-attention and Feedforward Network sub-blocks do not exploit the mean of their input, and the activations in pre-norm Transformers are approximately mean-zero by the second or third block, so the missing invariance is not measurable in real models. RMSNorm shares with LayerNorm the property that all coordinates of a vector influence the normalization of every other coordinate, which couples activations across the feature dimension and complicates certain forms of model surgery (such as pruning or per-feature quantization) more than would a per-feature normalizer like Batch Normalization.

    See Also

    References