Weight Standardization/es
| Article | |
|---|---|
| Topic area | deep learning |
| Prerequisites | Batch Normalization, Group Normalization, Convolutional Neural Network |
Resumen
La Estandarización de Pesos (Weight Standardization, WS) es una técnica de reparametrización para capas de redes neuronales que normaliza los pesos de una capa de convolución o lineal para que tengan media cero y varianza unitaria a lo largo de cada canal de salida antes de aplicarlos a la entrada. Introducida por Qiao, Wang, Liu, Shen y Yuille en 2019, está diseñada como complemento de los normalizadores basados en activaciones, como Group Normalization, Batch Normalization o Layer Normalization, y su principal motivación es recuperar el paisaje de pérdida favorable a la optimización de Batch Normalization en regímenes en los que las estadísticas por lote no son fiables, como el entrenamiento con micro-lotes o distribuido. A diferencia de las normalizaciones por lote y por grupo, que actúan sobre las activaciones y requieren estadísticas en ejecución o cómputo por muestra durante la inferencia, WS modifica únicamente los pesos y, por tanto, no introduce ningún coste adicional de inferencia más allá del paso hacia adelante de la capa subyacente.
WS suele combinarse con Group Normalization y se ha convertido en una receta estándar en detección de objetos, segmentación semántica y aprendizaje autosupervisado, donde la presión de memoria provocada por entradas de alta resolución obliga a usar tamaños de lote pequeños por dispositivo. También aparece en arquitecturas de visión modernas y modelos generativos que se benefician de una trayectoria de optimización más suave sin la discrepancia entre entrenamiento y prueba que introducen las estadísticas por lote.
Intuición
El éxito de Batch Normalization suele atribuirse no a su célebre reducción del desplazamiento de covariables interno, sino a un paisaje de pérdida más suave: acota la magnitud de las actualizaciones del gradiente y la constante de Lipschitz tanto de la pérdida como de su gradiente. Cuando el tamaño de lote se reduce, las estadísticas por lote se vuelven ruidosas, este efecto suavizador se degrada y la precisión cae bruscamente. La Estandarización de Pesos busca la misma propiedad de suavizado actuando sobre los propios parámetros en lugar de sobre las activaciones que producen.
La intuición es sencilla. Si las filas de una matriz de pesos tienen media y escala arbitrarias, pequeñas actualizaciones de los parámetros pueden producir cambios desproporcionadamente grandes en la salida de la capa. Al obligar a que cada filtro de salida tenga media cero y varianza unitaria sobre su fan-in, WS acota cuánto puede contribuir cada filtro a la salida antes de aplicarse cualquier función de activación. Esta cota, junto con un normalizador de activaciones posterior, mantiene la magnitud de las activaciones y de sus gradientes en un rango predecible a lo largo del entrenamiento.
Formulación
Sea $ W \in \mathbb{R}^{O \times I} $ los pesos de una capa, donde $ O $ es el número de canales de salida e $ I $ es el fan-in (para una convolución, $ I = C_{\text{in}} \cdot k_h \cdot k_w $). La Estandarización de Pesos sustituye $ W $ por una versión estandarizada $ \hat{W} $ definida por canal de salida:
$ {\displaystyle \hat{W}_{i, j} = \frac{W_{i, j} - \mu_i}{\sigma_i + \epsilon}, \quad \mu_i = \frac{1}{I} \sum_{j=1}^{I} W_{i, j}, \quad \sigma_i = \sqrt{\frac{1}{I} \sum_{j=1}^{I} (W_{i, j} - \mu_i)^2}} $
El paso hacia adelante utiliza entonces $ \hat{W} $ en lugar de $ W $:
$ {\displaystyle y = \hat{W} x + b} $
La estandarización es diferenciable, por lo que la retropropagación fluye a través de la normalización hasta los parámetros sin restricciones $ W $. WS por sí misma no introduce parámetros afines aprendibles; la ganancia y el sesgo suelen provenir de un normalizador de activaciones acoplado, como Group Normalization.
La transformación tiene dos efectos sobre los gradientes. En primer lugar, elimina la componente del gradiente que cambiaría la media de cada filtro, ya que los desplazamientos de la media quedan eliminados por el centrado. En segundo lugar, reescala el gradiente restante por $ 1/\sigma_i $, lo que actúa como un precondicionador por filtro. Qiao y sus colaboradores muestran que esto reduce la constante de Lipschitz de la pérdida y de su gradiente respecto a las activaciones, replicando el análisis de suavizado desarrollado previamente para Batch Normalization.
Entrenamiento e inferencia
WS se implementa como un envoltorio ligero alrededor del operador de convolución o lineal existente. Durante el entrenamiento, la estandarización se vuelve a calcular a partir de los pesos actuales en cada paso hacia adelante; los parámetros almacenados permanecen sin restricciones, y el optimizador (por ejemplo, descenso de gradiente estocástico con momento o Adam) los actualiza como de costumbre. Como la normalización es puramente función de los pesos, no se requieren estadísticas en ejecución, ni sincronización entre dispositivos, ni una división de comportamiento entre entrenamiento y evaluación.
En inferencia, los pesos estandarizados pueden recalcularse al vuelo o, más comúnmente, plegarse en la capa una sola vez y almacenarse, de modo que el modelo desplegado tenga exactamente el mismo perfil de cómputo y memoria que una convolución corriente. Cuando WS se combina con Group Normalization, la transformación combinada de normalizador-afín también puede fusionarse en los pesos de la convolución y en el sesgo para el despliegue, sin sobrecarga alguna.
WS interactúa de forma limpia con Weight Decay: dado que los gradientes respecto a la media y la escala se proyectan fuera, la caída de pesos aplicada a los parámetros sin restricciones reduce de hecho solo las direcciones que influyen en los pesos estandarizados, y los profesionales suelen mantener inalterados los coeficientes de decaimiento al añadir WS a una receta existente.
Variantes
Existen varias variantes que extienden o modifican el esquema básico. La Centered Weight Normalization centra pero no reescala; esto preserva el espíritu de Weight Normalization eliminando la media. La Scaled Weight Standardization, usada en la familia NFNet, multiplica los pesos estandarizados por una ganancia fija que compensa la varianza perdida a través de las no linealidades, lo que permite entrenar redes sin ningún normalizador de activaciones. La Equivariant Weight Standardization adapta WS a convoluciones equivariantes por grupos estandarizando dentro de cada órbita del grupo de simetría en lugar de hacerlo sobre el fan-in completo. Por último, varios autores aplican WS solo a un subconjunto de capas, excluyendo típicamente las convoluciones depthwise, donde el pequeño fan-in hace que las estadísticas por canal sean poco fiables.
Comparaciones
WS está estrechamente relacionada con Weight Normalization, aunque es distinta. La Weight Normalization desacopla la magnitud de cada filtro de su dirección escribiendo $ w = g \cdot v / \lVert v \rVert $ con un escalar aprendible $ g $; WS, en cambio, también resta la media y utiliza la desviación estándar empírica como normalizador, lo que es precisamente lo que produce el efecto de suavizado del gradiente. En comparación con Batch Normalization, WS no depende de las estadísticas por lote y, por tanto, no se degrada en regímenes de micro-lote o de acumulación; en comparación con Group Normalization por sí sola, cierra gran parte de la brecha residual con BN en tamaños de lote pequeños cuando se utiliza junto con GN. En comparación con Layer Normalization en transformadores, WS rara vez se usa porque LN ya opera por muestra y las matrices de pesos del matmul tienen una estructura estadística distinta de la de los filtros convolucionales.
Limitaciones
La técnica resulta más útil cuando el fan-in $ I $ es moderadamente grande; en capas con fan-in pequeño, como las convoluciones puntuales sobre canales estrechos o, especialmente, las convoluciones depthwise donde $ I = k_h \cdot k_w $, la media y la varianza por canal se estiman a partir de muy pocos pesos y la estandarización puede convertirse en una fuente de ruido en lugar de suavizado. WS también supone que los filtros de media cero son un sesgo inductivo deseable, lo cual es empíricamente cierto para las convoluciones sobre imágenes pero menos evidente en dominios donde el signo de la media tiene contenido semántico. Por último, aunque WS elimina la discrepancia entre entrenamiento y prueba de Batch Normalization, no elimina por sí misma la necesidad de un normalizador de activaciones: la mayoría de los resultados publicados que alcanzan precisión de vanguardia combinan WS con Group Normalization o utilizan el diseño dedicado de NFNet en lugar de prescindir por completo de la normalización de activaciones.
Referencias
- ↑ Qiao, S., Wang, H., Liu, C., Shen, W., Yuille, A. Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. arXiv:1903.10520, 2019.
- ↑ Brock, A., De, S., Smith, S. L., Simonyan, K. High-Performance Large-Scale Image Recognition Without Normalization. Proceedings of the 38th International Conference on Machine Learning, 2021.
- ↑ Salimans, T., Kingma, D. P. Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks. Advances in Neural Information Processing Systems 29, 2016.
- ↑ Santurkar, S., Tsipras, D., Ilyas, A., Madry, A. How Does Batch Normalization Help Optimization? Advances in Neural Information Processing Systems 31, 2018.
- ↑ Wu, Y., He, K. Group Normalization. Proceedings of the European Conference on Computer Vision (ECCV), 2018.