Weight Standardization/zh
| Article | |
|---|---|
| Topic area | deep learning |
| Prerequisites | Batch Normalization, Group Normalization, Convolutional Neural Network |
概述
權重標準化(Weight Standardization,WS)是一種針對神經網絡層的重參數化技術,它在將卷積層或線性層的權重應用於輸入之前,將其每個輸出通道上的權重歸一化為零均值和單位方差。該方法由 Qiao、Wang、Liu、Shen 和 Yuille 於 2019 年提出,旨在作為基於激活值的歸一化方法(如 Group Normalization、Batch Normalization 或 Layer Normalization)的補充。其主要動機是在批統計量不可靠的情形下(例如微批量或分布式訓練)恢復 Batch Normalization 那種有利於優化的損失曲面。與作用於激活值、需要在推理時維護運行統計量或逐樣本計算的批歸一化和組歸一化不同,WS 僅修改權重,因此除底層層本身的前向傳播外不引入任何額外的推理開銷。
WS 最常與 Group Normalization 搭配使用,已經成為目標檢測、語義分割和自監督學習中的標準做法,因為高解析度輸入帶來的顯存壓力迫使每個設備只能使用較小的批大小。它也出現在現代視覺架構和生成模型中,這些模型受益於更平滑的優化軌跡,同時避免批統計量帶來的訓練與測試之間的不一致。
直觀理解
Batch Normalization 的成功通常並不歸因於其廣為引用的內部協變量偏移的減少,而是歸因於更平滑的損失曲面:它限制了梯度更新的幅度,以及損失及其梯度的利普希茨常數。當批大小減小時,批統計量會變得嘈雜,這種平滑效應隨之退化,準確率也會驟降。權重標準化通過作用於參數本身而非其產生的激活值,來尋求同樣的平滑性質。
直觀上很簡單。如果權重矩陣的各行具有任意的均值和尺度,那麼參數的微小更新可能會在層的輸出中產生不成比例的巨大變化。通過強制每個輸出濾波器在其 fan-in 上具有零均值和單位方差,WS 限制了每個濾波器在應用任何激活函數之前對輸出所能貢獻的大小。這一約束與下游的激活歸一化器相結合,使激活值及其梯度的幅度在整個訓練過程中保持在可預測的範圍內。
公式表述
設 $ W \in \mathbb{R}^{O \times I} $ 表示某一層的權重,其中 $ O $ 是輸出通道數,$ I $ 是 fan-in(對於卷積而言,$ I = C_{\text{in}} \cdot k_h \cdot k_w $)。權重標準化將 $ W $ 替換為按每個輸出通道定義的標準化版本 $ \hat{W} $:
$ {\displaystyle \hat{W}_{i, j} = \frac{W_{i, j} - \mu_i}{\sigma_i + \epsilon}, \quad \mu_i = \frac{1}{I} \sum_{j=1}^{I} W_{i, j}, \quad \sigma_i = \sqrt{\frac{1}{I} \sum_{j=1}^{I} (W_{i, j} - \mu_i)^2}} $
隨後前向傳播使用 $ \hat{W} $ 替代 $ W $:
$ {\displaystyle y = \hat{W} x + b} $
標準化運算是可微的,因此反向傳播會沿著歸一化過程流向不受約束的參數 $ W $。WS 本身不引入可學習的仿射參數;增益和偏置通常來自配套的激活歸一化器,例如 Group Normalization。
該變換對梯度有兩個影響。其一,它消除了會改變每個濾波器均值的那部分梯度分量,因為均值方向的偏移已被中心化操作剔除。其二,它將剩餘梯度按 $ 1/\sigma_i $ 重新縮放,這相當於一個逐濾波器的預條件器。Qiao 等人證明,這降低了損失及其相對於激活值的梯度的利普希茨常數,與此前為 Batch Normalization 提出的平滑性分析相一致。
訓練與推理
WS 實現為對現有卷積或線性算子的一層薄封裝。在訓練期間,每次前向傳播都根據當前權重重新計算標準化;存儲的參數保持不受約束,優化器(例如帶動量的隨機梯度下降或 Adam)按常規方式對其進行更新。由於歸一化完全是權重的函數,無需運行統計量、無需跨設備同步,也無需在訓練與評估之間設置不同的行為分支。
在推理時,標準化後的權重既可以即時重新計算,也可以更常見地一次性摺疊進該層並保存,使得部署後的模型在計算量和顯存占用上與普通卷積完全一致。當 WS 與 Group Normalization 搭配使用時,歸一化與仿射變換的組合也可以融合進卷積權重和偏置中以用於部署,從而完全沒有額外開銷。
WS 與 Weight Decay 配合得很自然:由於關於均值和尺度的梯度被投影掉,應用於不受約束參數的權重衰減實際上只會收縮那些會影響標準化權重的方向,因此實踐者在將 WS 加入既有訓練配方時通常無需調整衰減係數。
變體
有幾種變體在基礎方案之上進行擴展或修改。Centered Weight Normalization 進行中心化但不進行重新縮放,這在保留 Weight Normalization 精神的同時去除了均值。在 NFNet 系列中使用的 Scaled Weight Standardization,會將標準化後的權重乘以一個固定增益,以補償在非線性中損失的方差,從而使網絡可以在完全不使用任何激活歸一化器的情況下進行訓練。Equivariant Weight Standardization 將 WS 推廣到群等變卷積,在對稱群的每個軌道內部進行標準化,而不是在整個 fan-in 上進行。最後,部分作者只對一部分層應用 WS,通常排除深度可分卷積層,因為這些層的 fan-in 很小,導致逐通道的統計量並不可靠。
對比
WS 與 Weight Normalization 關係密切但又有所不同。Weight Normalization 通過寫成 $ w = g \cdot v / \lVert v \rVert $ 並引入可學習的標量 $ g $,將每個濾波器的幅度與方向解耦;相比之下,WS 還要減去均值並以經驗標準差作為歸一化因子,這正是其產生梯度平滑效應的原因。與 Batch Normalization 相比,WS 不依賴批統計量,因此在微批量或梯度累積場景下不會退化;與單獨使用 Group Normalization 相比,它在與GN聯合使用時,可在小批大小下大幅縮小與BN之間的剩餘差距。與Transformer中的 Layer Normalization 相比,WS 很少被使用,因為LN本身就是逐樣本運算,且矩陣乘權重矩陣的統計結構與卷積濾波器有所不同。
局限性
該技術在 fan-in $ I $ 適中較大時最為有效;對於 fan-in 較小的層,例如作用於窄通道的逐點卷積,尤其是 $ I = k_h \cdot k_w $ 的深度可分卷積,逐通道的均值和方差是基於極少量權重估計出來的,標準化反而可能成為噪聲來源而非平滑機制。WS 還假設零均值的濾波器是一種理想的歸納偏置,這一假設在圖像卷積中經驗上成立,但在均值符號本身具有語義含義的領域中就不那麼顯然。最後,儘管 WS 消除了 Batch Normalization 中訓練與測試的不一致,但它本身並不能消除對激活歸一化器的需要:大多數達到最新水準準確率的報告結果都將 WS 與 Group Normalization 聯合使用,或採用專門設計的 NFNet,而不是完全拋棄激活歸一化。
參考文獻
- ↑ Qiao, S., Wang, H., Liu, C., Shen, W., Yuille, A. Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. arXiv:1903.10520, 2019.
- ↑ Brock, A., De, S., Smith, S. L., Simonyan, K. High-Performance Large-Scale Image Recognition Without Normalization. Proceedings of the 38th International Conference on Machine Learning, 2021.
- ↑ Salimans, T., Kingma, D. P. Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks. Advances in Neural Information Processing Systems 29, 2016.
- ↑ Santurkar, S., Tsipras, D., Ilyas, A., Madry, A. How Does Batch Normalization Help Optimization? Advances in Neural Information Processing Systems 31, 2018.
- ↑ Wu, Y., He, K. Group Normalization. Proceedings of the European Conference on Computer Vision (ECCV), 2018.