Weight Standardization/zh
| Article | |
|---|---|
| Topic area | deep learning |
| Prerequisites | Batch Normalization, Group Normalization, Convolutional Neural Network |
概述
权重标准化(Weight Standardization,WS)是一种针对神经网络层的重参数化技术,它在将卷积层或线性层的权重应用于输入之前,将其每个输出通道上的权重归一化为零均值和单位方差。该方法由 Qiao、Wang、Liu、Shen 和 Yuille 于 2019 年提出,旨在作为基于激活值的归一化方法(如 Group Normalization、Batch Normalization 或 Layer Normalization)的补充。其主要动机是在批统计量不可靠的情形下(例如微批量或分布式训练)恢复 Batch Normalization 那种有利于优化的损失曲面。与作用于激活值、需要在推理时维护运行统计量或逐样本计算的批归一化和组归一化不同,WS 仅修改权重,因此除底层层本身的前向传播外不引入任何额外的推理开销。
WS 最常与 Group Normalization 搭配使用,已经成为目标检测、语义分割和自监督学习中的标准做法,因为高分辨率输入带来的显存压力迫使每个设备只能使用较小的批大小。它也出现在现代视觉架构和生成模型中,这些模型受益于更平滑的优化轨迹,同时避免批统计量带来的训练与测试之间的不一致。
直观理解
Batch Normalization 的成功通常并不归因于其广为引用的内部协变量偏移的减少,而是归因于更平滑的损失曲面:它限制了梯度更新的幅度,以及损失及其梯度的利普希茨常数。当批大小减小时,批统计量会变得嘈杂,这种平滑效应随之退化,准确率也会骤降。权重标准化通过作用于参数本身而非其产生的激活值,来寻求同样的平滑性质。
直观上很简单。如果权重矩阵的各行具有任意的均值和尺度,那么参数的微小更新可能会在层的输出中产生不成比例的巨大变化。通过强制每个输出滤波器在其 fan-in 上具有零均值和单位方差,WS 限制了每个滤波器在应用任何激活函数之前对输出所能贡献的大小。这一约束与下游的激活归一化器相结合,使激活值及其梯度的幅度在整个训练过程中保持在可预测的范围内。
公式表述
设 $ W \in \mathbb{R}^{O \times I} $ 表示某一层的权重,其中 $ O $ 是输出通道数,$ I $ 是 fan-in(对于卷积而言,$ I = C_{\text{in}} \cdot k_h \cdot k_w $)。权重标准化将 $ W $ 替换为按每个输出通道定义的标准化版本 $ \hat{W} $:
$ {\displaystyle \hat{W}_{i, j} = \frac{W_{i, j} - \mu_i}{\sigma_i + \epsilon}, \quad \mu_i = \frac{1}{I} \sum_{j=1}^{I} W_{i, j}, \quad \sigma_i = \sqrt{\frac{1}{I} \sum_{j=1}^{I} (W_{i, j} - \mu_i)^2}} $
随后前向传播使用 $ \hat{W} $ 替代 $ W $:
$ {\displaystyle y = \hat{W} x + b} $
标准化运算是可微的,因此反向传播会沿着归一化过程流向不受约束的参数 $ W $。WS 本身不引入可学习的仿射参数;增益和偏置通常来自配套的激活归一化器,例如 Group Normalization。
该变换对梯度有两个影响。其一,它消除了会改变每个滤波器均值的那部分梯度分量,因为均值方向的偏移已被中心化操作剔除。其二,它将剩余梯度按 $ 1/\sigma_i $ 重新缩放,这相当于一个逐滤波器的预条件器。Qiao 等人证明,这降低了损失及其相对于激活值的梯度的利普希茨常数,与此前为 Batch Normalization 提出的平滑性分析相一致。
训练与推理
WS 实现为对现有卷积或线性算子的一层薄封装。在训练期间,每次前向传播都根据当前权重重新计算标准化;存储的参数保持不受约束,优化器(例如带动量的随机梯度下降或 Adam)按常规方式对其进行更新。由于归一化完全是权重的函数,无需运行统计量、无需跨设备同步,也无需在训练与评估之间设置不同的行为分支。
在推理时,标准化后的权重既可以即时重新计算,也可以更常见地一次性折叠进该层并保存,使得部署后的模型在计算量和显存占用上与普通卷积完全一致。当 WS 与 Group Normalization 搭配使用时,归一化与仿射变换的组合也可以融合进卷积权重和偏置中以用于部署,从而完全没有额外开销。
WS 与 Weight Decay 配合得很自然:由于关于均值和尺度的梯度被投影掉,应用于不受约束参数的权重衰减实际上只会收缩那些会影响标准化权重的方向,因此实践者在将 WS 加入既有训练配方时通常无需调整衰减系数。
变体
有几种变体在基础方案之上进行扩展或修改。Centered Weight Normalization 进行中心化但不进行重新缩放,这在保留 Weight Normalization 精神的同时去除了均值。在 NFNet 系列中使用的 Scaled Weight Standardization,会将标准化后的权重乘以一个固定增益,以补偿在非线性中损失的方差,从而使网络可以在完全不使用任何激活归一化器的情况下进行训练。Equivariant Weight Standardization 将 WS 推广到群等变卷积,在对称群的每个轨道内部进行标准化,而不是在整个 fan-in 上进行。最后,部分作者只对一部分层应用 WS,通常排除深度可分卷积层,因为这些层的 fan-in 很小,导致逐通道的统计量并不可靠。
对比
WS 与 Weight Normalization 关系密切但又有所不同。Weight Normalization 通过写成 $ w = g \cdot v / \lVert v \rVert $ 并引入可学习的标量 $ g $,将每个滤波器的幅度与方向解耦;相比之下,WS 还要减去均值并以经验标准差作为归一化因子,这正是其产生梯度平滑效应的原因。与 Batch Normalization 相比,WS 不依赖批统计量,因此在微批量或梯度累积场景下不会退化;与单独使用 Group Normalization 相比,它在与GN联合使用时,可在小批大小下大幅缩小与BN之间的剩余差距。与Transformer中的 Layer Normalization 相比,WS 很少被使用,因为LN本身就是逐样本运算,且矩阵乘权重矩阵的统计结构与卷积滤波器有所不同。
局限性
该技术在 fan-in $ I $ 适中较大时最为有效;对于 fan-in 较小的层,例如作用于窄通道的逐点卷积,尤其是 $ I = k_h \cdot k_w $ 的深度可分卷积,逐通道的均值和方差是基于极少量权重估计出来的,标准化反而可能成为噪声来源而非平滑机制。WS 还假设零均值的滤波器是一种理想的归纳偏置,这一假设在图像卷积中经验上成立,但在均值符号本身具有语义含义的领域中就不那么显然。最后,尽管 WS 消除了 Batch Normalization 中训练与测试的不一致,但它本身并不能消除对激活归一化器的需要:大多数达到最新水准准确率的报告结果都将 WS 与 Group Normalization 联合使用,或采用专门设计的 NFNet,而不是完全抛弃激活归一化。
参考文献
- ↑ Qiao, S., Wang, H., Liu, C., Shen, W., Yuille, A. Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. arXiv:1903.10520, 2019.
- ↑ Brock, A., De, S., Smith, S. L., Simonyan, K. High-Performance Large-Scale Image Recognition Without Normalization. Proceedings of the 38th International Conference on Machine Learning, 2021.
- ↑ Salimans, T., Kingma, D. P. Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks. Advances in Neural Information Processing Systems 29, 2016.
- ↑ Santurkar, S., Tsipras, D., Ilyas, A., Madry, A. How Does Batch Normalization Help Optimization? Advances in Neural Information Processing Systems 31, 2018.
- ↑ Wu, Y., He, K. Group Normalization. Proceedings of the European Conference on Computer Vision (ECCV), 2018.