LAMB Optimizer/es
| Article | |
|---|---|
| Topic area | Machine Learning |
| Prerequisites | Adam Optimizer, Stochastic Gradient Descent, Gradient Descent |
Visión general
El optimizador LAMB (Layer-wise Adaptive Moments for Batch training) es un algoritmo de optimización estocástica de primer orden diseñado para permitir el entrenamiento de redes neuronales profundas con tamaños de mini-lote muy grandes sin pérdida de generalización. Introducido por You et al. en 2019, LAMB combina las estimaciones adaptativas por parámetro de los momentos de Adam con una razón de confianza por capa inspirada en LARS (Layer-wise Adaptive Rate Scaling). El algoritmo ganó notoriedad tras ser utilizado para reducir el tiempo de pre-entrenamiento de BERT de aproximadamente tres días en una única TPU pod a 76 minutos en 1024 chips TPUv3, igualando o superando la puntuación F1 publicada en el benchmark SQuAD.[1]
LAMB se sitúa en la intersección de dos preocupaciones de larga data en aprendizaje profundo: cómo aprovechar el paralelismo de los aceleradores modernos (que favorece lotes grandes) y cómo preservar la regularización implícita que se atribuye al SGD de lotes pequeños. Su intuición central es que, con una tasa de aprendizaje global fija, las capas con normas de pesos muy distintas reciben actualizaciones de magnitudes efectivas muy diferentes, y que reescalar la actualización de cada capa por una razón de confianza entre la norma de los pesos y la norma de la actualización restablece un entrenamiento estable a tamaños de lote grandes.
Motivación: entrenamiento con lotes grandes
El entrenamiento distribuido en paralelo de datos escala el tamaño de lote efectivo proporcionalmente al número de trabajadores. En principio, duplicar el tamaño de lote y la tasa de aprendizaje a la vez preserva la trayectoria por época del descenso de gradiente, heurística conocida como regla de escalado lineal. En la práctica, esta regla deja de cumplirse a partir de un tamaño de lote crítico que depende del problema: el entrenamiento diverge, se estanca en una pérdida peor o generaliza mal. Goyal et al. demostraron el escalado lineal para ResNet-50 sobre ImageNet hasta un tamaño de lote de 8192 utilizando SGD con momento, calentamiento y normalización cuidadosa,[2] pero la misma receta falló para los modelos Transformer entrenados con Adam.
LARS, propuesto antes por You et al., abordó el caso de SGD introduciendo una razón de confianza por capa que adapta la tasa de aprendizaje a la norma de pesos de cada capa.[3] LARS llevó el entrenamiento de ResNet-50 hasta un tamaño de lote de 32K. Sin embargo, LARS se basa en SGD con momento, y aplicarlo directamente al entrenamiento de Transformer con optimizadores adaptativos produjo resultados inferiores. LAMB extiende la idea de la razón de confianza a los métodos basados en momentos adaptativos, que son la elección de facto para el pre-entrenamiento de transformadores.
Formulación algorítmica
Sea $ \theta_t \in \mathbb{R}^d $ el vector de parámetros en el paso $ t $, particionado en $ L $ capas como $ \theta_t = (\theta_t^{(1)}, \dots, \theta_t^{(L)}) $. Dado un gradiente estocástico $ g_t = \nabla_\theta \ell(\theta_t; \xi_t) $ sobre un mini-lote $ \xi_t $, LAMB realiza la siguiente actualización.
Primero, mantiene medias móviles exponenciales del gradiente y de su cuadrado, idénticas a las de Adam:
$ {\displaystyle m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t,} $
$ {\displaystyle v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t \odot g_t,} $
con estimaciones corregidas por sesgo $ \hat m_t = m_t / (1 - \beta_1^t) $ y $ \hat v_t = v_t / (1 - \beta_2^t) $.
Segundo, forma una actualización por coordenada al estilo Adam aumentada con decaimiento de pesos desacoplado $ \lambda $:
$ {\displaystyle r_t = \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_t.} $
Tercero, y de forma crucial, reescala la actualización de cada capa $ i $ mediante una razón de confianza por capa:
$ {\displaystyle \theta_{t+1}^{(i)} = \theta_t^{(i)} - \eta_t \cdot \frac{\phi\!\left(\| \theta_t^{(i)} \|\right)}{\| r_t^{(i)} \|}\, r_t^{(i)},} $
donde $ \| \cdot \| $ denota la norma L2 restringida a los parámetros de la capa $ i $, $ \eta_t $ es la tasa de aprendizaje global y $ \phi: \mathbb{R}_{\ge 0} \to \mathbb{R}_{>0} $ es una función de escalado. En la implementación canónica, $ \phi(x) = x $ (la identidad), opcionalmente recortada a un rango como $ [\phi_{\min}, \phi_{\max}] $ para evitar actualizaciones extremas en capas con normas de pesos muy pequeñas o muy grandes.
La razón de confianza $ \| \theta^{(i)} \| / \| r^{(i)} \| $ garantiza que, en esperanza, el cambio relativo $ \| \Delta \theta^{(i)} \| / \| \theta^{(i)} \| $ sea igual a $ \eta_t $, independientemente de la escala absoluta de la capa. Las capas cuya actualización Adam sería desproporcionadamente grande (en relación con los pesos actuales) se amortiguan, y las actualizaciones conservadoramente pequeñas se amplifican.
Pseudocódigo
input: learning rate eta, betas (b1, b2), epsilon, weight decay lambda
init: theta_0, m_0 = 0, v_0 = 0
for t = 1, 2, ... do
sample mini-batch, compute g_t
m_t = b1 * m_{t-1} + (1 - b1) * g_t
v_t = b2 * v_{t-1} + (1 - b2) * g_t * g_t
m_hat = m_t / (1 - b1**t)
v_hat = v_t / (1 - b2**t)
r_t = m_hat / (sqrt(v_hat) + epsilon) + lambda * theta_{t-1}
for each layer i do
w_norm = ||theta_{t-1}^(i)||
g_norm = ||r_t^(i)||
if w_norm > 0 and g_norm > 0:
trust = phi(w_norm) / g_norm
else:
trust = 1
theta_t^(i) = theta_{t-1}^(i) - eta * trust * r_t^(i)
end for
end for
Consideraciones prácticas
LAMB es más eficaz cuando se combina con un calendario de tasa de aprendizaje que incluya una fase de calentamiento. La receta original de BERT utilizaba un calentamiento lineal durante los primeros miles de pasos seguido de una caída polinómica; la tasa de aprendizaje pico es notablemente más alta que el rango típico de Adam, a menudo $ 10^{-2} $ o más para el pre-entrenamiento de transformadores, porque la razón de confianza absorbe la magnitud absoluta de las actualizaciones.
El decaimiento de pesos se desacopla del gradiente, en el sentido de AdamW; incluir el decaimiento dentro de $ g_t $ lo acoplaría a la normalización del segundo momento y anularía gran parte del beneficio. Los parámetros de sesgo y de normalización (por ejemplo, las escalas y desplazamientos de batch norm o de layer-norm) suelen quedar exentos tanto del decaimiento de pesos como del reescalado por razón de confianza, ya que sus normas son pequeñas y la razón de confianza puede volverse numéricamente inestable.
La estabilidad numérica también exige protegerse frente a capas con norma cero en la inicialización o tras la poda. La implementación de referencia recurre a una razón de confianza unitaria siempre que la norma de los pesos o la norma de la actualización sean cero.
Comparaciones
Comparado con Adam, LAMB introduce sólo una sobrecarga de factor constante modesta por paso, dominada por los cálculos de norma por capa. Sus beneficios sobre Adam son insignificantes con tamaños de lote pequeños, donde Adam ya converge bien; la diferencia aparece cuando el tamaño de lote supera varios miles de ejemplos. Comparado con LARS, LAMB hereda la adaptación por capa de LARS pero usa estimaciones de momentos adaptativos en lugar de momento, lo que lo hace más adecuado para arquitecturas tipo transformador en las que las magnitudes de los gradientes varían en órdenes de magnitud entre capas.
El artículo original presenta un análisis de convergencia bajo hipótesis estándar de suavidad y varianza acotada, mostrando una tasa $ O(1/\sqrt{T}) $ hacia un punto estacionario para objetivos no convexos. El análisis también muestra en qué sentido la actualización de LAMB es el único escalado que mantiene la norma de actualización por capa proporcional a la norma de pesos, lo que motiva el algoritmo más allá de razones puramente empíricas.
Limitaciones
LAMB no produce el mismo modelo entrenado que Adam con lotes pequeños. Las diferencias de regularización implícita entre lotes pequeños y grandes no se neutralizan por completo, y la exactitud en tareas posteriores puede ser ligeramente inferior con tamaños de lote muy grandes pese a igualar la pérdida de pre-entrenamiento. El algoritmo también expone hiperparámetros adicionales: el rango de recorte de la razón de confianza, las exenciones por grupo de parámetros y la elección de $ \phi $. En la práctica, los valores por defecto de la receta de BERT se transfieren bien a otros pre-entrenamientos de transformador, pero requieren reajuste para cargas de trabajo de visión o aprendizaje por refuerzo.
Empíricamente, LAMB ha tenido más éxito en el pre-entrenamiento de transformadores y en aprendizaje supervisado a gran escala. Sus ventajas sobre un Adam bien ajustado son menores en el ajuste fino, donde los lotes suelen ser modestos y el paisaje de optimización cerca de una inicialización pre-entrenada está bien condicionado. Para modelos muy pequeños o datos tabulares, suelen bastar optimizadores más sencillos.