Decoupled Weight Decay Regularization/paper/es

    From Marovi AI
    This page is a translated version of the page Decoupled Weight Decay Regularization/paper and the translation is 100% complete.
    Other languages:
    SummarySource
    Research Paper
    Authors Ilya Loshchilov; Frank Hutter
    Year 2017
    Topic area Machine Learning
    Difficulty Research
    arXiv 1711.05101
    PDF Download PDF

    Decoupled Weight Decay Regularization

    Ilya Loshchilov & Frank Hutter
    / Universidad de Friburgo
    / Friburgo, Alemania,
    / {ilya,fh}@cs.uni-freiburg.de

    Resumen

    La regularización L2 y la regularización por weight decay son equivalentes para el descenso de gradiente estocástico estándar (cuando se reescala por la learning rate), pero, como demostramos, este no es el caso para algoritmos de gradiente adaptativo, como Adam. Si bien las implementaciones habituales de estos algoritmos emplean regularización L2 (a menudo llamándola «weight decay» de un modo que puede inducir a error debido a la inequivalencia que aquí ponemos al descubierto), proponemos una modificación simple para recuperar la formulación original de la regularización por weight decay desacoplando la weight decay de los pasos de optimización realizados respecto a la función de pérdida. Aportamos evidencia empírica de que nuestra modificación propuesta (i) desacopla la elección óptima del factor de weight decay del valor de la learning rate tanto para SGD estándar como para Adam, y (ii) mejora sustancialmente el rendimiento de generalización de Adam, permitiéndole competir con SGD con momentum en datasets de clasificación de imágenes (en los que antes solía ser superado por este último). Nuestra weight decay desacoplada propuesta ya ha sido adoptada por muchos investigadores, y la comunidad la ha implementado en TensorFlow y PyTorch; el código fuente completo de nuestros experimentos está disponible en https://github.com/loshchil/AdamW-and-SGDW

    1 Introducción

    Los métodos de gradiente adaptativos, como AdaGrad (Duchi et al., 2011), RMSProp (Tieleman & Hinton, 2012), Adam (Kingma & Ba, 2014) y, más recientemente, AMSGrad (Reddi et al., 2018), se han convertido en el método por defecto para entrenar redes neuronales feed-forward y recurrentes (Xu et al., 2015; Radford et al., 2015). No obstante, los resultados estado del arte en datasets populares de clasificación de imágenes, como CIFAR-10 y CIFAR-100 Krizhevsky (2009), siguen obteniéndose aplicando SGD con momentum (Gastaldi, 2017; Cubuk et al., 2018). Además, Wilson et al. (2017) sugirieron que los métodos de gradiente adaptativos no generalizan tan bien como SGD con momentum cuando se evalúan sobre un conjunto diverso de tareas de aprendizaje profundo, como clasificación de imágenes, modelado de lenguaje a nivel de carácter y constituency parsing. Se han investigado distintas hipótesis sobre el origen de esta peor generalización, como la presencia de mínimos locales agudos (Keskar et al., 2016; Dinh et al., 2017) y problemas inherentes a los métodos de gradiente adaptativos (Wilson et al., 2017). En este artículo investigamos si es preferible usar regularización L2 o regularización por weight decay para entrenar redes neuronales profundas con SGD y Adam. Mostramos que un factor importante de la mala generalización del método de gradiente adaptativo más popular, Adam, se debe al hecho de que la regularización L2 no es ni de lejos tan efectiva para él como lo es para SGD. En concreto, nuestro análisis de Adam conduce a las siguientes observaciones:

    /
    La regularización L2 y la weight decay no son idénticas. Las dos técnicas pueden hacerse equivalentes para SGD mediante una reparametrización del factor de weight decay basada en la learning rate; sin embargo, como suele pasarse por alto, este no es el caso para Adam. En particular, cuando se combina con gradientes adaptativos, la regularización L2 hace que los pesos con amplitudes históricas grandes de parámetro y/o gradiente queden regularizados menos de lo que lo estarían usando weight decay. / ; / : La regularización L2 no es eficaz en Adam. Una posible explicación de por qué Adam y otros métodos de gradiente adaptativos pueden ser superados por SGD con momentum es que las bibliotecas de aprendizaje profundo habituales solo implementan regularización L2 y no la weight decay original. Por tanto, en tareas/datasets donde el uso de regularización L2 es beneficioso para SGD (p. ej., en muchos datasets populares de clasificación de imágenes), Adam produce peores resultados que SGD con momentum (para el cual la regularización L2 se comporta como cabe esperar). / ; / : La weight decay es igualmente eficaz en SGD y en Adam. Para SGD es equivalente a la regularización L2, mientras que para Adam no lo es. / ; / : La weight decay óptima depende del número total de pasadas por batch / actualizaciones de pesos. Nuestro análisis empírico de SGD y Adam sugiere que cuanto mayor sea el runtime / número de pasadas por batch a realizar, menor será la weight decay óptima. / ; / : Adam puede beneficiarse sustancialmente de un multiplicador de learning rate planificado. El hecho de que Adam sea un algoritmo de gradiente adaptativo y, como tal, adapte la learning rate para cada parámetro, no descarta la posibilidad de mejorar sustancialmente su rendimiento usando un multiplicador global de learning rate, planificado, p. ej., mediante cosine annealing.

    La principal contribución de este artículo es mejorar la regularización en Adam desacoplando la weight decay de la actualización basada en gradiente. En un análisis exhaustivo, mostramos que Adam generaliza sustancialmente mejor con weight decay desacoplada que con regularización L2, alcanzando una mejora relativa del 15 % en error de test (véanse las Figuras 2 y 3); esto se cumple en distintos datasets de reconocimiento de imágenes (CIFAR-10 e ImageNet32x32), presupuestos de entrenamiento (entre 100 y 1800 epochs) y schedules de learning rate (fijo, drop-step y cosine annealing; véase la Figura 1). También demostramos que nuestra weight decay desacoplada hace que las configuraciones óptimas de la learning rate y del factor de weight decay sean mucho más independientes, facilitando así la optimización de hiperparámetros (véase la Figura 2).

    La motivación principal de este artículo es mejorar Adam para hacerlo competitivo respecto a SGD con momentum incluso en aquellos problemas en los que antes no lo era. Esperamos que, como resultado, los profesionales no necesiten alternar entre Adam y SGD, lo que a su vez debería reducir el problema habitual de seleccionar algoritmos de entrenamiento e hiperparámetros específicos para cada dataset/tarea.

    2 Desacoplando la weight decay de la actualización basada en gradiente

    En la weight decay descrita por Hanson y Pratt (1988), los pesos $ {\textstyle \mathbf{θ}} $ decaen exponencialmente como

    $ {\textstyle {{\mathbf{θ}}_{t + 1} = {{{({1 - \lambda})}\hspace{0pt}{\mathbf{θ}}_{t}} - {\alpha\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}}}},} $ (1)

    donde $ {\textstyle \lambda} $ define la tasa de la weight decay por paso y $ {\textstyle {\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}} $ es el $ {\textstyle t} $-ésimo gradiente de batch que se multiplica por una learning rate $ {\textstyle \alpha} $. Para SGD estándar, esto es equivalente a la regularización L2 estándar:

    Proposición 1 (Weight decay = L2 reg para SGD estándar).

    El SGD estándar con base learning rate $ {\textstyle \alpha} $ ejecuta los mismos pasos sobre funciones de pérdida por batch $ {\textstyle f_{t}\hspace{0pt}{({\mathbf{θ}})}} $ con weight decay $ {\textstyle \lambda} $ (definida en la Ecuación 1) que ejecuta sin weight decay sobre $ {\textstyle {f_{t}^{\text{reg}}\hspace{0pt}{({\mathbf{θ}})}} = {{f_{t}\hspace{0pt}{({\mathbf{θ}})}} + {\frac{\lambda^{\prime}}{2}\hspace{0pt}\left. \parallel{\mathbf{θ}}\parallel \right._{2}^{2}}}} $, con $ {\textstyle \lambda^{\prime} = \frac{\lambda}{\alpha}} $.

    Las pruebas de este hecho bien conocido, así como de nuestras demás proposiciones, se dan en el Apéndice A.

    Debido a esta equivalencia, la regularización L2 se denomina con mucha frecuencia weight decay, incluso en bibliotecas populares de aprendizaje profundo. Sin embargo, como demostraremos más adelante en esta sección, esta equivalencia no se cumple para los métodos de gradiente adaptativos. Un hecho que a menudo se pasa por alto incluso para el caso simple de SGD es que, para que la equivalencia se mantenga, el regularizador L2 $ {\textstyle \lambda^{\prime}} $ debe fijarse a $ {\textstyle \frac{\lambda}{\alpha}} $; es decir, si existe un mejor valor global de weight decay $ {\textstyle \lambda} $, el mejor valor de $ {\textstyle \lambda^{\prime}} $ está estrechamente acoplado con la learning rate $ {\textstyle \alpha} $. Para desacoplar los efectos de estos dos hiperparámetros, abogamos por desacoplar el paso de weight decay tal como propusieron Hanson y Pratt (1988) (Ecuación 1).

    1:  given initial learning rate $ {\textstyle \alpha \in {IR}} $, momentum factor $ {\textstyle \beta_{1} \in {IR}} $, weight decay/L2 regularization factor $ {\textstyle \lambda \in {IR}} $ 2:  initialize time step $ {\textstyle t\leftarrow 0} $, parameter vector $ {\textstyle {\mathbf{θ}}_{t = 0} \in {IR}^{n}} $, first moment vector $ {\textstyle \text{m}_{t = 0}\leftarrow\text{0}} $, schedule multiplier $ {\textstyle \eta_{t = 0} \in {IR}} $ 3:  repeat 4:     $ {\textstyle t\leftarrow{t + 1}} $ 5:     $ {\textstyle {{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t - 1})}}\leftarrow{\text{SelectBatch}\hspace{0pt}{({\mathbf{θ}}_{t - 1})}}} $ $ {\textstyle \rhd} $  select batch and return the corresponding gradient 6:     $ {\textstyle \text{g}_{t}\leftarrow{{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t - 1})}}} $ $ {\textstyle + {\lambda\hspace{0pt}{\mathbf{θ}}_{t - 1}}} $ 7:     $ {\textstyle \eta_{t}\leftarrow{\text{SetScheduleMultiplier}\hspace{0pt}{(t)}}} $ $ {\textstyle \rhd} $  can be fixed, decay, be used for warm restarts 8:     $ {\textstyle \text{m}_{t}\leftarrow{{\beta_{1}\hspace{0pt}\text{m}_{t - 1}} + {\eta_{t}\hspace{0pt}\alpha\hspace{0pt}\text{g}_{t}}}} $ 9:     $ {\textstyle {\mathbf{θ}}_{t}\leftarrow{{\mathbf{θ}}_{t - 1} - \text{m}_{t}}} $ $ {\textstyle - {\eta_{t}\hspace{0pt}\lambda\hspace{0pt}{\mathbf{θ}}_{t - 1}}} $ 10:  until  stopping criterion is met 11:  return  optimized parameters $ {\textstyle {\mathbf{θ}}_{t}} $

    1:  given $ {\textstyle {\alpha = 0.001},{{\beta_{1} = 0.9},{{\beta_{2} = 0.999},{{\epsilon = 10^{- 8}},{\lambda \in {IR}}}}}} $ 2:  initialize time step $ {\textstyle t\leftarrow 0} $, parameter vector $ {\textstyle {\mathbf{θ}}_{t = 0} \in {IR}^{n}} $, first moment vector $ {\textstyle \text{m}_{t = 0}\leftarrow\text{0}} $, second moment vector $ {\textstyle \text{v}_{t = 0}\leftarrow\text{0}} $, schedule multiplier $ {\textstyle \eta_{t = 0} \in {IR}} $ 3:  repeat 4:     $ {\textstyle t\leftarrow{t + 1}} $ 5:     $ {\textstyle {{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t - 1})}}\leftarrow{\text{SelectBatch}\hspace{0pt}{({\mathbf{θ}}_{t - 1})}}} $ $ {\textstyle \rhd} $  select batch and return the corresponding gradient 6:     $ {\textstyle \text{g}_{t}\leftarrow{{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t - 1})}}} $ $ {\textstyle + {\lambda\hspace{0pt}{\mathbf{θ}}_{t - 1}}} $ 7:     $ {\textstyle \text{m}_{t}\leftarrow{{\beta_{1}\hspace{0pt}\text{m}_{t - 1}} + {{({1 - \beta_{1}})}\hspace{0pt}\text{g}_{t}}}} $ $ {\textstyle \rhd} $  here and below all operations are element-wise 8:     $ {\textstyle \text{v}_{t}\leftarrow{{\beta_{2}\hspace{0pt}\text{v}_{t - 1}} + {{({1 - \beta_{2}})}\hspace{0pt}\text{g}_{t}^{2}}}} $ 9:     $ {\textstyle {\hat{\text{m}}}_{t}\leftarrow{\text{m}_{t}/{({1 - \beta_{1}^{t}})}}} $ $ {\textstyle \rhd} $  $ {\textstyle \beta_{1}} $ is taken to the power of $ {\textstyle t} $ 10:     $ {\textstyle {\hat{\text{v}}}_{t}\leftarrow{\text{v}_{t}/{({1 - \beta_{2}^{t}})}}} $ $ {\textstyle \rhd} $  $ {\textstyle \beta_{2}} $ is taken to the power of $ {\textstyle t} $ 11:     $ {\textstyle \eta_{t}\leftarrow{\text{SetScheduleMultiplier}\hspace{0pt}{(t)}}} $ $ {\textstyle \rhd} $  can be fixed, decay, or also be used for warm restarts 12:     $ {\textstyle {\mathbf{θ}}_{t}\leftarrow{{\mathbf{θ}}_{t - 1} - {\eta_{t}\hspace{0pt}\left( {{{\alpha\hspace{0pt}{\hat{\text{m}}}_{t}}/{({\sqrt{{\hat{\text{v}}}_{t}} + \epsilon})}}\hspace{0pt}{+ {\lambda\hspace{0pt}{\mathbf{θ}}_{t - 1}}}} \right)}}} $ 13:  until  stopping criterion is met 14:  return  optimized parameters $ {\textstyle {\mathbf{θ}}_{t}} $

    Mirando primero al caso de SGD, proponemos hacer decay de los pesos simultáneamente con la actualización de $ {\textstyle {\mathbf{θ}}_{t}} $ basada en información del gradiente, en la Línea 9 del Algoritmo 1. Esto da lugar a nuestra variante propuesta de SGD con momentum usando weight decay desacoplada (SGDW). Esta simple modificación desacopla explícitamente $ {\textstyle \lambda} $ y $ {\textstyle \alpha} $ (aunque, naturalmente, puede persistir cierto acoplamiento implícito dependiente del problema, como ocurre con cualesquiera dos hiperparámetros). Para tener en cuenta una posible planificación tanto de $ {\textstyle \alpha} $ como de $ {\textstyle \lambda} $, introducimos un factor de escalado $ {\textstyle \eta_{t}} $ proporcionado por un procedimiento definido por el usuario $ {\textstyle S\hspace{0pt}e\hspace{0pt}t\hspace{0pt}S\hspace{0pt}c\hspace{0pt}h\hspace{0pt}e\hspace{0pt}d\hspace{0pt}u\hspace{0pt}l\hspace{0pt}e\hspace{0pt}M\hspace{0pt}u\hspace{0pt}l\hspace{0pt}t\hspace{0pt}i\hspace{0pt}p\hspace{0pt}l\hspace{0pt}i\hspace{0pt}e\hspace{0pt}r\hspace{0pt}{(t)}} $.

    Pasemos ahora a algoritmos de gradiente adaptativos como el popular optimizador Adam Kingma y Ba (2014), que escalan los gradientes según sus magnitudes históricas. Intuitivamente, cuando se ejecuta Adam sobre una función de pérdida $ {\textstyle f} $ más regularización L2, los pesos que tienden a tener gradientes grandes en $ {\textstyle f} $ no quedan tan regularizados como lo estarían con weight decay desacoplada, ya que el gradiente del regularizador se escala junto con el gradiente de $ {\textstyle f} $. Esto da lugar a una inequivalencia entre la regularización L2 y la weight decay desacoplada para algoritmos de gradiente adaptativos:

    Proposición 2 (Weight decay \neq L2 reg para gradientes adaptativos).

    Sea $ {\textstyle O} $ un optimizador con iteraciones $ {\textstyle {\mathbf{θ}}_{t + 1}\leftarrow{{\mathbf{θ}}_{t} - {\alpha\hspace{0pt}\mathbf{M}_{t}\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}}}} $ al ejecutarse sobre una función de pérdida por batch $ {\textstyle f_{t}\hspace{0pt}{({\mathbf{θ}})}} $ sin weight decay, y $ {\textstyle {\mathbf{θ}}_{t + 1}\leftarrow{{{({1 - \lambda})}\hspace{0pt}{\mathbf{θ}}_{t}} - {\alpha\hspace{0pt}\mathbf{M}_{t}\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}}}} $ al ejecutarse sobre $ {\textstyle f_{t}\hspace{0pt}{({\mathbf{θ}})}} $ con weight decay, respectivamente, con $ {\textstyle \mathbf{M}_{t} \neq {k\hspace{0pt}\mathbf{I}}} $ (donde $ {\textstyle k \in {\mathbb{R}}} $). Entonces, para $ {\textstyle O} $ no existe ningún coeficiente L2 $ {\textstyle \lambda^{\prime}} $ tal que ejecutar $ {\textstyle O} $ sobre la pérdida por batch $ {\textstyle {f_{t}^{\text{reg}}\hspace{0pt}{({\mathbf{θ}})}} = {{f_{t}\hspace{0pt}{({\mathbf{θ}})}} + {\frac{\lambda^{\prime}}{2}\hspace{0pt}\left. \parallel{\mathbf{θ}}\parallel \right._{2}^{2}}}} $ sin weight decay sea equivalente a ejecutar $ {\textstyle O} $ sobre $ {\textstyle f_{t}\hspace{0pt}{({\mathbf{θ}})}} $ con decay $ {\textstyle \lambda \in {\mathbb{R}}^{+}} $.

    Desacoplamos la weight decay y las actualizaciones de gradiente basadas en la pérdida en Adam tal como se muestra en la línea 12 del Algoritmo 2; esto da lugar a nuestra variante de Adam con weight decay desacoplada (AdamW).

    Haber mostrado que la regularización L2 y la regularización por weight decay difieren para los algoritmos de gradiente adaptativos plantea la cuestión de en qué difieren y cómo interpretar sus efectos. Su equivalencia para SGD estándar sigue siendo muy útil para la intuición: ambos mecanismos empujan los pesos hacia cero a la misma tasa. Sin embargo, para los algoritmos de gradiente adaptativos difieren: con regularización L2, las sumas del gradiente de la función de pérdida y del gradiente del regularizador (es decir, la norma L2 de los pesos) son adaptadas, mientras que con weight decay desacoplada solo se adaptan los gradientes de la función de pérdida (con el paso de weight decay separado del mecanismo de gradiente adaptativo). Con regularización L2, ambos tipos de gradientes son normalizados por sus magnitudes típicas (sumadas), y por lo tanto los pesos $ {\textstyle x} $ con magnitud típica de gradiente $ {\textstyle s} $ grande son regularizados en una cantidad relativa menor que el resto. En cambio, la weight decay desacoplada regulariza todos los pesos a la misma tasa $ {\textstyle \lambda} $, regularizando efectivamente los pesos $ {\textstyle x} $ con $ {\textstyle s} $ grande más de lo que lo hace la regularización L2 estándar. Lo demostramos formalmente para un caso especial sencillo de algoritmo de gradiente adaptativo con un precondicionador fijo:

    Proposición 3 (Weight decay = L_{2} reg ajustada por escala para algoritmo de gradiente adaptativo con precondicionador fijo).

    Sea $ {\textstyle O} $ un algoritmo con las mismas características que en la Proposición 2, y que utiliza una matriz precondicionadora fija $ {\textstyle \text{M}_{t} = {\text{diag}\hspace{0pt}{(\text{s})}^{- 1}}} $ (con $ {\textstyle s_{i} > 0} $ para todo $ {\textstyle i} $). Entonces, $ {\textstyle O} $ con base learning rate $ {\textstyle \alpha} $ ejecuta los mismos pasos sobre funciones de pérdida por batch $ {\textstyle f_{t}\hspace{0pt}{({\mathbf{θ}})}} $ con weight decay $ {\textstyle \lambda} $ que ejecuta sin weight decay sobre la pérdida por batch regularizada con ajuste por escala

    $ {\displaystyle {{f_{t}^{\text{sreg}}\hspace{0pt}{({\mathbf{θ}})}} = {{f_{t}\hspace{0pt}{({\mathbf{θ}})}} + {\frac{\lambda^{\prime}}{2\hspace{0pt}\alpha}\hspace{0pt}\left. \parallel{{\mathbf{θ}} \odot \sqrt{\text{s}}}\parallel \right._{2}^{2}}}},} $ (2)

    donde $ {\textstyle \odot} $ y $ {\textstyle \sqrt{\cdot}} $ denotan la multiplicación y la raíz cuadrada elemento a elemento, respectivamente, y $ {\textstyle \lambda^{\prime} = \frac{\lambda}{\alpha}} $.

    Notamos que esta proposición no se aplica directamente a algoritmos de gradiente adaptativos prácticos, ya que estos cambian la matriz precondicionadora en cada paso. No obstante, puede aún proporcionar intuición sobre la función de pérdida equivalente que se está optimizando en cada paso: los parámetros $ {\textstyle \theta_{i}} $ con un precondicionador inverso grande $ {\textstyle s_{i}} $ (que en la práctica estaría causado por gradientes históricamente grandes en la dimensión $ {\textstyle i} $) son regularizados relativamente más de lo que lo serían con regularización L2; específicamente, la regularización es proporcional a $ {\textstyle \sqrt{s_{i}}} $.

    3 Justificación de la weight decay desacoplada mediante una visión de los métodos de gradiente adaptativos como filtrado bayesiano

    Discutimos ahora una justificación de la weight decay desacoplada en el marco del filtrado bayesiano para una teoría unificada de los algoritmos de gradiente adaptativos debida a Aitchison (2018). Tras publicar una versión preliminar de nuestro artículo actual en arXiv, Aitchison observó que su teoría «nos proporciona un marco teórico en el que podemos entender la superioridad de esta weight decay sobre la regularización $ {\textstyle L_{2}} $, porque es la weight decay, y no la regularización $ {\textstyle L_{2}} $, la que emerge a través de la aplicación directa del filtrado bayesiano» (Aitchison, 2018). Si bien todo el mérito de esta teoría corresponde a Aitchison, la resumimos aquí para arrojar algo de luz sobre por qué la weight decay puede ser preferible a la regularización $ {\textstyle L_{2}} $.

    Aitchison (2018) ve la optimización estocástica de $ {\textstyle n} $ parámetros $ {\textstyle \theta_{1},\ldots,\theta_{n}} $ como un problema de filtrado bayesiano cuyo objetivo es inferir una distribución sobre los valores óptimos de cada uno de los parámetros $ {\textstyle \theta_{i}} $, dados los valores actuales del resto de los parámetros $ {\textstyle {\mathbf{θ}}_{- i}\hspace{0pt}{(t)}} $ en el paso temporal $ {\textstyle t} $. Cuando los demás parámetros no cambian se trata de un problema de optimización, pero cuando sí cambian se convierte en uno de «seguimiento» del optimizador mediante filtrado bayesiano del siguiente modo. Se tiene una distribución de probabilidad $ {\textstyle P\hspace{0pt}{({{\mathbf{θ}}_{t} \mid {\mathbf{y}}_{\mathbf{1}:{\mathbf{t}}}})}} $ del optimizador en el paso temporal $ {\textstyle t} $ que tiene en cuenta los datos $ {\textstyle {\mathbf{y}}_{\mathbf{1}:{\mathbf{t}}}} $ de los primeros $ {\textstyle t} $ mini batches, una prior de transición de estados $ {\textstyle P\hspace{0pt}{({{\mathbf{θ}}_{t + 1} \mid {\mathbf{θ}}_{t}})}} $ que refleja un cambio (pequeño) e independiente de los datos en esta distribución de un paso al siguiente, y una verosimilitud $ {\textstyle P\hspace{0pt}{({{\mathbf{y}}_{t + 1} \mid {\mathbf{θ}}_{t + 1}})}} $ derivada del mini batch en el paso $ {\textstyle t + 1} $. La distribución posterior $ {\textstyle P\hspace{0pt}{({{\mathbf{θ}}_{t + 1} \mid {\mathbf{y}}_{\mathbf{1}:{{\mathbf{t}} + \mathbf{1}}}})}} $ del optimizador en el paso $ {\textstyle t + 1} $ puede entonces calcularse (como es habitual en filtrado bayesiano) marginalizando sobre $ {\textstyle {\mathbf{θ}}_{t}} $ para obtener las predicciones a un paso $ {\textstyle P\hspace{0pt}{({{\mathbf{θ}}_{t + 1} \mid {\mathbf{y}}_{\mathbf{1}:{\mathbf{t}}}})}} $ y aplicando luego la regla de Bayes para incorporar la verosimilitud $ {\textstyle P\hspace{0pt}{({{\mathbf{y}}_{t + 1} \mid {\mathbf{θ}}_{t + 1}})}} $. Aitchison (2018) supone una distribución de transición de estados gaussiana $ {\textstyle P\hspace{0pt}{({{\mathbf{θ}}_{t + 1} \mid {\mathbf{θ}}_{t}})}} $ y una verosimilitud aproximadamente conjugada $ {\textstyle P\hspace{0pt}{({{\mathbf{y}}_{t + 1} \mid {\mathbf{θ}}_{t + 1}})}} $, lo que conduce a la siguiente actualización en forma cerrada de la media de la distribución de filtrado:

    $ {\displaystyle {{\mathbf{μ}}_{p\hspace{0pt}o\hspace{0pt}s\hspace{0pt}t} = {{\mathbf{μ}}_{p\hspace{0pt}r\hspace{0pt}i\hspace{0pt}o\hspace{0pt}r} + {\mathbf{\Sigma}_{p\hspace{0pt}o\hspace{0pt}s\hspace{0pt}t} \times {\mathbf{g}}}}},} $ (3)

    donde $ {\textstyle \mathbf{g}} $ es el gradiente de la log-verosimilitud del mini batch en el tiempo $ {\textstyle t} $. Este resultado implica un precondicionador de los gradientes dado por la incertidumbre posterior $ {\textstyle \mathbf{\Sigma}_{p\hspace{0pt}o\hspace{0pt}s\hspace{0pt}t}} $ de la distribución de filtrado: las actualizaciones son mayores para los parámetros sobre los que tenemos más incertidumbre y menores para aquellos sobre los que tenemos más certeza. Aitchison (2018) muestra a continuación que métodos de gradiente adaptativos populares, como Adam y RMSprop, así como métodos factorizados con Kronecker, son casos particulares de este marco.

    La weight decay desacoplada encaja muy naturalmente en este marco unificado como parte de la distribución de transición de estados: Aitchison (2018) supone un cambio lento del optimizador según la siguiente gaussiana:

    $ {\displaystyle {{P\hspace{0pt}{({{\mathbf{θ}}_{t + 1} \mid {\mathbf{θ}}_{t}})}} = {\mathcal{N}\hspace{0pt}{({{({{\mathbf{I}} - {\mathbf{A}}})}\hspace{0pt}{\mathbf{θ}}_{t}},{\mathbf{Q}})}}},} $ (4)

    donde $ {\textstyle \mathbf{Q}} $ es la covarianza de las perturbaciones gaussianas de los pesos, y $ {\textstyle \mathbf{A}} $ es un regularizador para evitar que los valores crezcan ilimitadamente con el tiempo. Cuando se instancia como $ {\textstyle {\mathbf{A}} = {\lambda \times {\mathbf{I}}}} $, este regularizador $ {\textstyle \mathbf{A}} $ juega exactamente el papel de la weight decay desacoplada descrita en la Ecuación 1, ya que ello conduce a multiplicar la estimación actual de la media $ {\textstyle {\mathbf{θ}}_{t}} $ por $ {\textstyle ({1 - \lambda})} $ en cada paso. Notablemente, esta regularización también se aplica directamente a la prior y no depende de la incertidumbre en cada uno de los parámetros (lo cual sería necesario para la regularización $ {\textstyle L_{2}} $).

    Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption

    4 Validación experimental

    Evaluamos ahora el rendimiento de la weight decay desacoplada bajo distintos presupuestos de entrenamiento y schedules de learning rate. Nuestro montaje experimental sigue el de Gastaldi (2017), quien propuso, además de la regularización L2, aplicar la nueva regularización Shake-Shake a una DNN residual de 3 ramas que permitió alcanzar nuevos resultados estado del arte del 2,86 % en el dataset CIFAR-10 (Krizhevsky, 2009). Usamos el mismo modelo / código fuente basado en fb.resnet.torch 111https://github.com/xgastaldi/shake-shake. Siempre utilizamos un batch size de 128 y aplicamos el procedimiento habitual de data augmentation para los datasets CIFAR. Las redes base son una ResNet 26 2x64d (es decir, la red tiene una profundidad de 26, 2 ramas residuales y el primer bloque residual tiene una anchura de 64) y una ResNet 26 2x96d con 11,6M y 25,6M parámetros respectivamente. Para una descripción detallada de la red y del método Shake-Shake, remitimos al lector interesado a Gastaldi (2017). También realizamos experimentos sobre el dataset ImageNet32x32 (Chrabaszcz et al., 2017), una versión submuestreada del dataset original ImageNet con 1,2 millones de imágenes de 32$ {\textstyle \times} $32 píxeles.

    4.1 Evaluación de la weight decay desacoplada con distintos schedules de learning rate

    En nuestro primer experimento, comparamos Adam con regularización $ {\textstyle L_{2}} $ con Adam con weight decay desacoplada (AdamW), usando tres schedules de learning rate distintos: una learning rate fija, un schedule drop-step y un schedule de cosine annealing (Loshchilov & Hutter, 2016). Dado que Adam ya adapta sus learning rates por parámetro, no es tan común usar un schedule de multiplicador de learning rate con él como con SGD, pero, como muestran nuestros resultados, tales schedules pueden mejorar sustancialmente el rendimiento de Adam, y abogamos por no pasar por alto su uso en algoritmos de gradiente adaptativos.

    Para cada schedule de learning rate y cada variante de weight decay, entrenamos una ResNet 2x64d durante 100 epochs, usando distintas configuraciones de la learning rate inicial $ {\textstyle \alpha} $ y del factor de weight decay $ {\textstyle \lambda} $. La Figura 1 muestra que la weight decay desacoplada supera a la regularización $ {\textstyle L_{2}} $ para todos los schedules de learning rate, con diferencias mayores para schedules mejores. También observamos que la weight decay desacoplada conduce a un espacio de búsqueda de hiperparámetros más separable, especialmente cuando se aplica un schedule de learning rate como step-drop o cosine annealing. La figura también muestra que cosine annealing supera claramente al resto de schedules de learning rate; por ello usamos cosine annealing en el resto de los experimentos.

    Refer to caption           Refer to caption

    Refer to caption           Refer to caption

    Refer to caption               Refer to caption

    Refer to caption               Refer to caption

    4.2 Desacoplando los parámetros de weight decay y learning rate inicial

    Para verificar nuestra hipótesis sobre el acoplamiento de $ {\textstyle \alpha} $ y $ {\textstyle \lambda} $, en la Figura 2 comparamos el rendimiento de la regularización L2 frente a la weight decay desacoplada en SGD (SGD vs. SGDW, fila superior) y en Adam (Adam vs. AdamW, fila inferior). En SGD (Figura 2, arriba a la izquierda), la regularización L2 no está desacoplada de la learning rate (la forma habitual descrita en el Algoritmo 1), y la figura muestra claramente que la cuenca de los mejores hiperparámetros (representada por color y, mediante círculos negros, las 10 mejores configuraciones) no está alineada con el eje x ni con el y, sino que se encuentra sobre la diagonal. Esto sugiere que los dos hiperparámetros son interdependientes y deben cambiarse simultáneamente, mientras que cambiar solo uno de ellos puede empeorar sustancialmente los resultados. Considérese, p. ej., la configuración del círculo negro arriba a la izquierda ($ {\textstyle \alpha = {1/2}} $, $ {\textstyle \lambda = {{1/8} \ast 0.001}} $); cambiar solo $ {\textstyle \alpha} $ o solo $ {\textstyle \lambda} $ empeoraría los resultados, mientras que cambiar ambos podría aún producir mejoras claras. Notamos que este acoplamiento de la learning rate inicial y el factor de regularización L2 podría haber contribuido a la fama de SGD de ser muy sensible a sus configuraciones de hiperparámetros.

    Por el contrario, los resultados para SGD con weight decay desacoplada (SGDW) en la Figura 2 (arriba a la derecha) muestran que la weight decay y la learning rate inicial están desacopladas. El enfoque propuesto hace que los dos hiperparámetros sean más separables: incluso si la learning rate todavía no está bien ajustada (p. ej., considérese el valor de 1/1024 en la Figura 2, arriba a la derecha), dejarla fija y optimizar solo el factor de weight decay produciría un buen valor (de 1/4*0.001). No es así para SGD con regularización L2 (véase la Figura 2, arriba a la izquierda).

    Los resultados de Adam con regularización L2 aparecen en la Figura 2 (abajo a la izquierda). Las mejores configuraciones de hiperparámetros de Adam se comportaron claramente peor que las mejores de SGD (compárese con la Figura 2, arriba a la izquierda). Aunque ambos métodos usaron regularización L2, Adam no se benefició en absoluto de ella: sus mejores resultados con factores de regularización L2 distintos de cero fueron comparables a los mejores obtenidos sin regularización L2, es decir, cuando $ {\textstyle \lambda = 0} $. De forma similar al SGD original, la forma del paisaje de hiperparámetros sugiere que los dos hiperparámetros están acoplados.

    En contraste, los resultados de nuestra nueva variante de Adam con weight decay desacoplada (AdamW) en la Figura 2 (abajo a la derecha) muestran que AdamW desacopla en gran medida la weight decay y la learning rate. Los resultados para las mejores configuraciones de hiperparámetros fueron sustancialmente mejores que los mejores de Adam con regularización L2 y rivalizaron con los de SGD y SGDW.

    En resumen, los resultados de la Figura 2 respaldan nuestra hipótesis de que los hiperparámetros de weight decay y learning rate pueden desacoplarse, y que esto, a su vez, simplifica el problema del ajuste de hiperparámetros en SGD y mejora el rendimiento de Adam para hacerlo competitivo respecto a SGD con momentum.

    4.3 Mejor generalización de AdamW

    Si bien el experimento anterior sugería que la cuenca de hiperparámetros óptimos de AdamW es más amplia y profunda que la de Adam, a continuación investigamos los resultados para corridas mucho más largas, de 1800 epochs, para comparar las capacidades de generalización de AdamW y Adam.

    Fijamos la learning rate inicial a 0,001, que representa tanto la learning rate por defecto para Adam como la que mostró resultados razonablemente buenos en nuestros experimentos. La Figura 3 muestra los resultados para 12 configuraciones de la regularización L2 de Adam y 7 configuraciones de la weight decay normalizada de AdamW (la weight decay normalizada representa un reescalado formalmente definido en el Apéndice B.1; equivale a un factor multiplicativo que depende del número de pasadas por batch). Curiosamente, mientras que las dinámicas de las curvas de aprendizaje de Adam y AdamW solían coincidir durante la primera mitad del entrenamiento, AdamW conducía a menudo a menor training loss y menores errores de test (véase la Figura 3 arriba a la izquierda y arriba a la derecha, respectivamente). Es importante notar que el uso de weight decay L2 en Adam no produjo resultados tan buenos como la weight decay desacoplada en AdamW (véase también la Figura 3, abajo a la izquierda). A continuación investigamos si los mejores resultados de AdamW se debían solo a una mejor convergencia o a una mejor generalización. Los resultados de la Figura 3 (abajo a la derecha) para las mejores configuraciones de Adam y AdamW sugieren que AdamW no solo produjo mejor training loss, sino también mejor rendimiento de generalización para valores similares de training loss. Los resultados en ImageNet32x32 (véase SuppFigure 4 en el Apéndice) llevan a la misma conclusión de un rendimiento de generalización sustancialmente mejorado.

    Refer to caption               Refer to caption

    4.4 AdamWR con warm restarts para mejor rendimiento anytime

    Para mejorar el rendimiento anytime de SGDW y AdamW, los extendimos con los warm restarts que introdujimos en Loshchilov & Hutter (2016), obteniendo SGDWR y AdamWR, respectivamente (véase la Sección B.2 del Apéndice). Como muestra la Figura 4, AdamWR aceleró enormemente AdamW en CIFAR-10 e ImageNet32x32, hasta en un factor de 10 (véanse los resultados en el primer restart). Para la learning rate por defecto de 0,001, AdamW alcanzó una mejora relativa del 15 % en error de test respecto a Adam tanto en CIFAR-10 (véase también la SuppFigure 5) como en ImageNet32x32 (véase también la SuppFigure 6).

    AdamWR alcanzó los mismos resultados mejorados pero con un rendimiento anytime mucho mejor. Estas mejoras cerraron la mayor parte de la brecha entre Adam y SGDWR en CIFAR-10 y produjeron un rendimiento comparable en ImageNet32x32.

    4.5 Uso de AdamW en otros datasets y arquitecturas

    Otros grupos de investigación ya han aplicado con éxito AdamW en trabajos citables. Por ejemplo, Wang et al. (2018) usaron AdamW para entrenar una nueva arquitectura de detección de rostros sobre el dataset estándar WIDER FACE (Yang et al., 2016), obteniendo predicciones casi 10 veces más rápidas que los algoritmos previos del estado del arte y alcanzando un rendimiento comparable. Völker et al. (2018) emplearon AdamW con cosine annealing para entrenar redes neuronales convolucionales que clasifican y caracterizan señales cerebrales relacionadas con errores medidas a partir de registros de electroencefalografía intracraneal (EEG). Si bien su artículo no proporciona una comparación con Adam, amablemente nos facilitaron una comparación directa entre ambos sobre su mejor arquitectura de red específica del problema, Deep4Net, y una variante de ResNet. AdamW con la misma configuración de hiperparámetros que Adam produjo una exactitud de test set superior en Deep4Net (73,68 % frente a 71,37 %) y una exactitud de test set estadísticamente significativamente superior en ResNet (72,04 % frente a 61,34 %). Radford et al. (2018) emplearon AdamW para entrenar arquitecturas Transformer (Vaswani et al., 2017) obteniendo nuevos resultados estado del arte en una amplia gama de benchmarks de comprensión del lenguaje natural. Zhang et al. (2018) compararon regularización L2 frente a weight decay para SGD, Adam y el optimizador Kronecker-Factored Approximate Curvature (K-FAC) (Martens & Grosse, 2015) en los datasets CIFAR con arquitecturas ResNet y VGG, reportando que la weight decay desacoplada superó consistentemente a la regularización L2 en los casos en que difieren.

    5 Conclusión y trabajo futuro

    Siguiendo las sugerencias de que los métodos de gradiente adaptativos como Adam podrían conducir a peor generalización que SGD con momentum (Wilson et al., 2017), identificamos y expusimos la inequivalencia entre la regularización L2 y la weight decay para Adam. Mostramos empíricamente que nuestra versión de Adam con weight decay desacoplada produce un rendimiento de generalización sustancialmente mejor que la implementación habitual de Adam con regularización L2. También propusimos usar warm restarts para Adam con el fin de mejorar su rendimiento anytime.

    Nuestros resultados obtenidos en datasets de clasificación de imágenes deben verificarse en una gama más amplia de tareas, especialmente aquellas en las que se espera que el uso de regularización sea importante. Sería interesante integrar nuestros hallazgos sobre weight decay con otros métodos que tratan de mejorar Adam, p. ej., Adam con preservación de dirección normalizada (Zhang et al., 2017). Aunque centramos nuestro análisis experimental en Adam, creemos que resultados similares también se cumplen para otros métodos de gradiente adaptativos, como AdaGrad (Duchi et al., 2011) y AMSGrad (Reddi et al., 2018).

    6 Agradecimientos

    Agradecemos a Patryk Chrabaszcz su ayuda con la ejecución de los experimentos en ImageNet32x32; a Matthias Feurer y Robin Schirrmeister por proporcionar valiosos comentarios sobre este artículo en varias iteraciones; y a Martin Völker, Robin Schirrmeister y Tonio Ball por proporcionarnos una comparación de AdamW y Adam sobre sus datos de EEG. También agradecemos a los siguientes miembros de la comunidad de aprendizaje profundo por implementar la weight decay desacoplada en distintas bibliotecas de aprendizaje profundo:

    Este trabajo fue financiado por el European Research Council (ERC) en el marco del programa de investigación e innovación Horizonte 2020 de la Unión Europea bajo la subvención n.º 716721, por la Deutsche Forschungsgemeinschaft (DFG) bajo el Cluster of Excellence BrainLinksBrainTools (subvención n.º EXC 1086) y mediante la subvención n.º INST 37/935-1 FUGG, y por el estado alemán de Baden-Württemberg a través de bwHPC.

    Referencias

    • Aitchison (2018) Laurence Aitchison. A unified theory of adaptive stochastic gradient descent as Bayesian filtering. arXiv:1507.02030, 2018. / * Chrabaszcz et al. (2017) Patryk Chrabaszcz, Ilya Loshchilov, and Frank Hutter. A downsampled variant of ImageNet as an alternative to the CIFAR datasets. arXiv:1707.08819, 2017. / * Cubuk et al. (2018) Ekin D Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V Le. Autoaugment: Learning augmentation policies from data. arXiv preprint arXiv:1805.09501, 2018. / * Dinh et al. (2017) Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. arXiv:1703.04933, 2017. / * Duchi et al. (2011) John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods for online learning and stochastic optimization. The Journal of Machine Learning Research, 12:2121–2159, 2011. / * Gastaldi (2017) Xavier Gastaldi. Shake-Shake regularization. arXiv preprint arXiv:1705.07485, 2017. / * Hanson & Pratt (1988) Stephen José Hanson and Lorien Y Pratt. Comparing biases for minimal network construction with back-propagation. En Proceedings of the 1st International Conference on Neural Information Processing Systems, pp. 177–185, 1988. / * Huang et al. (2017) Gao Huang, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E Hopcroft, and Kilian Q Weinberger. Snapshot ensembles: Train 1, get m for free. arXiv:1704.00109, 2017. / * Keskar et al. (2016) Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv:1609.04836, 2016. / * Kingma & Ba (2014) Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv:1412.6980, 2014. / * Krizhevsky (2009) Alex Krizhevsky. Learning multiple layers of features from tiny images. 2009. / * Li et al. (2017) Hao Li, Zheng Xu, Gavin Taylor, and Tom Goldstein. Visualizing the loss landscape of neural nets. arXiv preprint arXiv:1712.09913, 2017. / * Loshchilov & Hutter (2016) Ilya Loshchilov and Frank Hutter. SGDR: stochastic gradient descent with warm restarts. arXiv:1608.03983, 2016. / * Martens & Grosse (2015) James Martens and Roger Grosse. Optimizing neural networks with kronecker-factored approximate curvature. En International conference on machine learning, pp. 2408–2417, 2015. / * Radford et al. (2015) Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv:1511.06434, 2015. / * Radford et al. (2018) Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. Improving language understanding by generative pre-training. URL https://s3-us-west-2. amazonaws. com/openai-assets/research-covers/language-unsupervised/language_ understanding_paper. pdf, 2018. / * Reddi et al. (2018) Sashank J. Reddi, Satyen Kale, and Sanjiv Kumar. On the convergence of adam and beyond. International Conference on Learning Representations, 2018. / * Smith (2016) Leslie N Smith. Cyclical learning rates for training neural networks. arXiv:1506.01186v3, 2016. / * Tieleman & Hinton (2012) Tijmen Tieleman and Geoffrey Hinton. Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural networks for machine learning, 4(2):26–31, 2012. / * Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. En Advances in Neural Information Processing Systems, pp. 5998–6008, 2017. / * Völker et al. (2018) Martin Völker, Jiří Hammer, Robin T Schirrmeister, Joos Behncke, Lukas DJ Fiederer, Andreas Schulze-Bonhage, Petr Marusič, Wolfram Burgard, and Tonio Ball. Intracranial error detection via deep learning. arXiv preprint arXiv:1805.01667, 2018. / * Wang et al. (2018) Jianfeng Wang, Ye Yuan, Gang Yu, and Sun Jian. Sface: An efficient network for face detection in large scale variations. arXiv preprint arXiv:1804.06559, 2018. / * Wilson et al. (2017) Ashia C Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin Recht. The marginal value of adaptive gradient methods in machine learning. arXiv:1705.08292, 2017. / * Xu et al. (2015) Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhudinov, Rich Zemel, and Yoshua Bengio. Show, attend and tell: Neural image caption generation with visual attention. En International Conference on Machine Learning, pp. 2048–2057, 2015. / * Yang et al. (2016) Shuo Yang, Ping Luo, Chen-Change Loy, and Xiaoou Tang. Wider face: A face detection benchmark. En Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5525–5533, 2016. / * Zhang et al. (2018) Guodong Zhang, Chaoqi Wang, Bowen Xu, and Roger Grosse. Three mechanisms of weight decay regularization. arXiv preprint arXiv:1810.12281, 2018. / * Zhang et al. (2017) Zijun Zhang, Lin Ma, Zongpeng Li, and Chuan Wu. Normalized direction-preserving adam. arXiv:1709.04546, 2017. / * Zoph et al. (2017) Barret Zoph, Vijay Vasudevan, Jonathon Shlens, and Quoc V. Le. Learning transferable architectures for scalable image recognition. En arXiv:1707.07012 [cs.CV], 2017.

    Appendix

    Apéndice A Análisis formal de Weight Decay vs Regularización L2

    Demostración de la Proposición 1
    / La demostración de este hecho bien conocido es directa. SGD sin weight decay tiene las siguientes iteraciones sobre $ {\textstyle {f_{t}^{\text{reg}}\hspace{0pt}{({\mathbf{θ}})}} = {{f_{t}\hspace{0pt}{({\mathbf{θ}})}} + {\frac{\lambda^{\prime}}{2}\hspace{0pt}\left. \parallel{\mathbf{θ}}\parallel \right._{2}^{2}}}} $:

    $ {\displaystyle {{\mathbf{θ}}_{t + 1}\leftarrow{{\mathbf{θ}}_{t} - {\alpha\hspace{0pt}{\nabla f_{t}^{\text{reg}}}\hspace{0pt}{({\mathbf{θ}}_{t})}}} = {{\mathbf{θ}}_{t} - {\alpha\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}} - {\alpha\hspace{0pt}\lambda^{\prime}\hspace{0pt}{\mathbf{θ}}_{t}}}}.} $ (5)

    SGD con weight decay tiene las siguientes iteraciones sobre $ {\textstyle f_{t}\hspace{0pt}{({\mathbf{θ}})}} $:

    $ {\displaystyle {{\mathbf{θ}}_{t + 1}\leftarrow{{{({1 - \lambda})}\hspace{0pt}{\mathbf{θ}}_{t}} - {\alpha\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}}}}.} $ (6)

    Estas iteraciones son idénticas, ya que $ {\textstyle \lambda^{\prime} = \frac{\lambda}{\alpha}} $. ∎

    Demostración de la Proposición 2
    / De forma análoga a la demostración de la Proposición 1, las iteraciones de $ {\textstyle O} $ sin weight decay sobre $ {\textstyle {f_{t}^{\text{reg}}\hspace{0pt}{({\mathbf{θ}})}} = {{f_{t}\hspace{0pt}{({\mathbf{θ}})}} + {\frac{1}{2}\hspace{0pt}\lambda^{\prime}\hspace{0pt}\left. \parallel{\mathbf{θ}}\parallel \right._{2}^{2}}}} $ y de $ {\textstyle O} $ con weight decay $ {\textstyle \lambda} $ sobre $ {\textstyle f_{t}} $ son, respectivamente:

    $ {\textstyle {\mathbf{θ}}_{t + 1}} $ $ {\textstyle \leftarrow} $ $ {\textstyle {{\mathbf{θ}}_{t} - {\alpha\hspace{0pt}\lambda^{\prime}\hspace{0pt}\mathbf{M}_{t}\hspace{0pt}{\mathbf{θ}}_{t}} - {\alpha\hspace{0pt}\mathbf{M}_{t}\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}}}.} $ (7)
    $ {\textstyle {\mathbf{θ}}_{t + 1}} $ $ {\textstyle \leftarrow} $ $ {\textstyle {{{({1 - \lambda})}\hspace{0pt}{\mathbf{θ}}_{t}} - {\alpha\hspace{0pt}\mathbf{M}_{t}\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}}}.} $ (8)

    La igualdad de estas iteraciones para todo $ {\textstyle {\mathbf{θ}}_{t}} $ implicaría $ {\textstyle {\lambda\hspace{0pt}{\mathbf{θ}}_{t}} = {\alpha\hspace{0pt}\lambda^{\prime}\hspace{0pt}\mathbf{M}_{t}\hspace{0pt}{\mathbf{θ}}_{t}}} $. Esto solo puede cumplirse para todo $ {\textstyle {\mathbf{θ}}_{t}} $ si $ {\textstyle \mathbf{M}_{t} = {k\hspace{0pt}\mathbf{I}}} $, con $ {\textstyle k \in {\mathbb{R}}} $, lo cual no es el caso de $ {\textstyle O} $. Por tanto, no existe ningún regularizador L2 $ {\textstyle \lambda^{\prime}\hspace{0pt}\left. \parallel{\mathbf{θ}}\parallel \right._{2}^{2}} $ que haga las iteraciones equivalentes. ∎

    Demostración de la Proposición 3
    / $ {\textstyle O} $ sin weight decay tiene las siguientes iteraciones sobre $ {\textstyle {f_{t}^{\text{sreg}}\hspace{0pt}{({\mathbf{θ}})}} = {{f_{t}\hspace{0pt}{({\mathbf{θ}})}} + {\frac{\lambda^{\prime}}{2}\hspace{0pt}\left. \parallel{{\mathbf{θ}} \odot \sqrt{\text{s}}}\parallel \right._{2}^{2}}}} $:

    $ {\textstyle {\mathbf{θ}}_{t + 1}} $ $ {\textstyle \leftarrow} $ $ {\textstyle {\mathbf{θ}}_{t} - {{\alpha\hspace{0pt}{\nabla f_{t}^{\text{sreg}}}\hspace{0pt}{({\mathbf{θ}}_{t})}}/\text{s}}} $ (9)
    $ {\textstyle =} $ $ {\textstyle {\mathbf{θ}}_{t} - {{\alpha\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}}/\text{s}} - {{{\alpha\hspace{0pt}\lambda^{\prime}\hspace{0pt}{\mathbf{θ}}_{t}} \odot \text{s}}/\text{s}}} $ (10)
    $ {\textstyle =} $ $ {\textstyle {{\mathbf{θ}}_{t} - {{\alpha\hspace{0pt}{\nabla f_{t}}\hspace{0pt}{({\mathbf{θ}}_{t})}}/\text{s}} - {\alpha\hspace{0pt}\lambda^{\prime}\hspace{0pt}{\mathbf{θ}}_{t}}},} $ (11)

    donde la división por s es elemento a elemento. $ {\textstyle O} $ con weight decay tiene las siguientes iteraciones sobre $ {\textstyle f_{t}\hspace{0pt}{({\mathbf{θ}})}} $:

    $ {\textstyle {\mathbf{θ}}_{t + 1}} $ $ {\textstyle \leftarrow} $ $ {\textstyle {{({1 - \lambda})}\hspace{0pt}{\mathbf{θ}}_{t}} - {{\alpha\hspace{0pt}{\nabla f}\hspace{0pt}{({\mathbf{θ}}_{t})}}/\text{s}}} $ (12)
    $ {\textstyle =} $ $ {\textstyle {{\mathbf{θ}}_{t} - {{\alpha\hspace{0pt}{\nabla f}\hspace{0pt}{({\mathbf{θ}}_{t})}}/\text{s}} - {\lambda\hspace{0pt}{\mathbf{θ}}_{t}}},} $ (13)

    Estas iteraciones son idénticas, ya que $ {\textstyle \lambda^{\prime} = \frac{\lambda}{\alpha}} $. ∎

    Apéndice B Mejoras prácticas adicionales de Adam

    Habiendo discutido la weight decay desacoplada para mejorar la generalización de Adam, en esta sección introducimos dos componentes adicionales para mejorar el rendimiento de Adam en la práctica.

    B.1 Weight decay normalizada

    Nuestros experimentos preliminares mostraron que distintos factores de weight decay son óptimos para distintos presupuestos computacionales (definidos en términos del número de pasadas por batch). De forma relacionada, Li et al. (2017) demostraron que un batch size más pequeño (para el mismo número total de epochs) hace que el efecto de contracción de la weight decay sea más pronunciado. Aquí proponemos reducir esta dependencia normalizando los valores de weight decay. En concreto, sustituimos el hiperparámetro $ {\textstyle \lambda} $ por un nuevo hiperparámetro de weight decay normalizada (más robusto) $ {\textstyle \lambda_{n\hspace{0pt}o\hspace{0pt}r\hspace{0pt}m}} $, y lo usamos para fijar $ {\textstyle \lambda} $ como $ {\textstyle \lambda = {\lambda_{n\hspace{0pt}o\hspace{0pt}r\hspace{0pt}m}\hspace{0pt}\sqrt{\frac{b}{B\hspace{0pt}T}}}} $, donde $ {\textstyle b} $ es el batch size, $ {\textstyle B} $ es el número total de puntos de entrenamiento y $ {\textstyle T} $ es el número total de epochs.222En el contexto de nuestra variante AdamWR discutida en la Sección B.2, $ {\textstyle T} $ es el número total de epochs en el restart actual. Así, $ {\textstyle \lambda_{n\hspace{0pt}o\hspace{0pt}r\hspace{0pt}m}} $ puede interpretarse como la weight decay que se utilizaría si solo se permitiera una pasada por batch. Subrayamos que nuestra elección de normalización es solo una posibilidad informada por unos pocos experimentos; una conclusión más duradera que extraemos es que usar alguna normalización puede mejorar sustancialmente los resultados.

    B.2 Adam con cosine annealing y warm restarts

    Aplicamos ahora cosine annealing y warm restarts a Adam, siguiendo nuestro trabajo reciente (Loshchilov & Hutter, 2016). Allí propusimos el descenso de gradiente estocástico con warm restarts (SGDR) para mejorar el rendimiento anytime de SGD enfriando rápidamente la learning rate según un schedule cosine y aumentándola periódicamente. SGDR ha sido adoptado con éxito para conducir a nuevos resultados estado del arte en benchmarks populares de clasificación de imágenes (Huang et al., 2017; Gastaldi, 2017; Zoph et al., 2017), por lo que ya intentamos extenderlo a Adam poco después de proponerlo. Sin embargo, aunque nuestra versión inicial de Adam con warm restarts tenía mejor rendimiento anytime que Adam, no era competitiva con SGD con warm restarts, precisamente porque la regularización L2 no funcionaba tan bien como en SGD. Ahora, una vez resuelto este problema mediante la regularización original por weight decay (Sección 2) y habiendo introducido también la weight decay normalizada (Sección B.1), nuestro trabajo original sobre cosine annealing y warm restarts se traslada directamente a Adam.

    Para mantener la presentación autocontenida, describimos brevemente cómo SGDR planifica el cambio de la learning rate efectiva con el fin de acelerar el entrenamiento de DNNs. Aquí, desacoplamos la learning rate inicial $ {\textstyle \alpha} $ y su multiplicador $ {\textstyle \eta_{t}} $ usado para obtener la learning rate real en la iteración $ {\textstyle t} $ (véase, p. ej., la línea 8 del Algoritmo 1). En SGDR simulamos una nueva ejecución/restart con warm-start de SGD una vez realizadas $ {\textstyle T_{i}} $ epochs, donde $ {\textstyle i} $ es el índice de la ejecución. Es importante destacar que los reinicios no se realizan desde cero, sino que se emulan aumentando $ {\textstyle \eta_{t}} $ mientras se usa el valor antiguo de $ {\textstyle {\mathbf{θ}}_{t}} $ como solución inicial. La cantidad en la que se incrementa $ {\textstyle \eta_{t}} $ controla en qué medida se utiliza la información previamente adquirida (p. ej., el momentum). Dentro de la $ {\textstyle i} $-ésima ejecución, el valor de $ {\textstyle \eta_{t}} $ decae de acuerdo con una learning rate por cosine annealing (Loshchilov & Hutter, 2016) para cada batch como sigue:

    $ {\textstyle {\eta_{t} = {\eta_{m\hspace{0pt}i\hspace{0pt}n}^{(i)} + {0.5\hspace{0pt}{({\eta_{m\hspace{0pt}a\hspace{0pt}x}^{(i)} - \eta_{m\hspace{0pt}i\hspace{0pt}n}^{(i)}})}\hspace{0pt}{({1 + {\cos{({{\pi\hspace{0pt}T_{c\hspace{0pt}u\hspace{0pt}r}}/T_{i}})}}})}}}},} $ (14)

    donde $ {\textstyle \eta_{m\hspace{0pt}i\hspace{0pt}n}^{(i)}} $ y $ {\textstyle \eta_{m\hspace{0pt}a\hspace{0pt}x}^{(i)}} $ son los rangos del multiplicador y $ {\textstyle T_{c\hspace{0pt}u\hspace{0pt}r}} $ contabiliza cuántas epochs se han realizado desde el último restart. $ {\textstyle T_{c\hspace{0pt}u\hspace{0pt}r}} $ se actualiza en cada iteración de batch $ {\textstyle t} $ y, por tanto, no está restringido a valores enteros. Ajustar (p. ej., disminuir) $ {\textstyle \eta_{m\hspace{0pt}i\hspace{0pt}n}^{(i)}} $ y $ {\textstyle \eta_{m\hspace{0pt}a\hspace{0pt}x}^{(i)}} $ en cada $ {\textstyle i} $-ésimo restart (véase también Smith (2016)) podría mejorar potencialmente el rendimiento, pero no consideramos esa opción aquí porque introduciría hiperparámetros adicionales. Para $ {\textstyle \eta_{m\hspace{0pt}a\hspace{0pt}x}^{(i)} = 1} $ y $ {\textstyle \eta_{m\hspace{0pt}i\hspace{0pt}n}^{(i)} = 0} $, la Ec. (14) puede simplificarse a

    $ {\textstyle {\eta_{t} = {0.5 + {0.5\hspace{0pt}{\cos{({{\pi\hspace{0pt}T_{c\hspace{0pt}u\hspace{0pt}r}}/T_{i}})}}}}}.} $ (15)

    Para conseguir un buen rendimiento anytime, se puede empezar con un $ {\textstyle T_{i}} $ inicialmente pequeño (p. ej., del 1 % al 10 % del presupuesto total esperado) y multiplicarlo por un factor $ {\textstyle T_{m\hspace{0pt}u\hspace{0pt}l\hspace{0pt}t}} $ (p. ej., $ {\textstyle T_{m\hspace{0pt}u\hspace{0pt}l\hspace{0pt}t} = 2} $) en cada restart. El $ {\textstyle ({i + 1})} $-ésimo restart se dispara cuando $ {\textstyle T_{c\hspace{0pt}u\hspace{0pt}r} = T_{i}} $ poniendo $ {\textstyle T_{c\hspace{0pt}u\hspace{0pt}r}} $ a 0. Un ejemplo de configuración del multiplicador de schedule se da en C.

    Nuestro algoritmo AdamWR propuesto representa AdamW (véase Algoritmo 2) con $ {\textstyle \eta_{t}} $ siguiendo la Ec. (15) y $ {\textstyle \lambda} $ calculado en cada iteración usando la weight decay normalizada descrita en la Sección B.1. Notamos que la weight decay normalizada nos permitió usar una configuración de parámetros constante a lo largo de las ejecuciones cortas y largas realizadas dentro de AdamWR y SGDWR (SGDW con warm restarts).

    Apéndice C Un ejemplo de configuración del multiplicador de schedule

    Un ejemplo de schedule del multiplicador $ {\textstyle \eta_{t}} $ aparece en SuppFigure 1 para $ {\textstyle T_{i = 0} = 100} $ y $ {\textstyle T_{m\hspace{0pt}u\hspace{0pt}l\hspace{0pt}t} = 2} $. Tras los primeros 100 epochs la learning rate llegará a 0 porque $ {\textstyle \eta_{t = 100} = 0} $. Entonces, dado que $ {\textstyle T_{c\hspace{0pt}u\hspace{0pt}r} = T_{i = 0}} $, hacemos un restart fijando $ {\textstyle T_{c\hspace{0pt}u\hspace{0pt}r} = 0} $, lo que provoca que el multiplicador $ {\textstyle \eta_{t}} $ se reinicie a 1 por la Ec. (15). Este multiplicador volverá a decrecer de 1 a 0, pero ahora a lo largo de 200 epochs, ya que $ {\textstyle T_{i = 1} = {T_{i = 0}\hspace{0pt}T_{m\hspace{0pt}u\hspace{0pt}l\hspace{0pt}t}} = 200} $. Las soluciones obtenidas justo antes de los reinicios, cuando $ {\textstyle \eta_{t} = 0} $ (p. ej., en los índices de epoch 100, 300, 700 y 1500 mostrados en SuppFigure 1), se recomiendan por el optimizador como las soluciones, dando prioridad a las más recientes.

    Refer to caption

    Apéndice D Resultados adicionales

    Investigamos si el uso de ejecuciones mucho más largas (1800 epochs) de «Adam estándar» (Adam con regularización L2 y learning rate fija) hace innecesario el uso de cosine annealing. La SuppFigure 2 muestra los resultados de Adam estándar para una rejilla logarítmica de 4×4 de configuraciones de hiperparámetros (la baja resolución de la rejilla se debe al alto coste computacional de las ejecuciones de 1800 epochs). Incluso teniendo en cuenta la baja resolución de la rejilla, los resultados parecen ser, en el mejor de los casos, comparables a los obtenidos con AdamW con 18 veces menos epochs y una red más pequeña (véase SuppFigure 3, fila superior, en el medio). Estos resultados no son muy sorprendentes a la vista de la Figura 1 del artículo principal (que muestra tanto las mejoras posibles al usar algún schedule de learning rate, como cosine annealing, como la efectividad de la weight decay desacoplada).

    Nuestros resultados experimentales con Adam y SGD sugieren que el runtime total en términos del número de epochs afecta a la cuenca de hiperparámetros óptimos (véase SuppFigure 3). Más concretamente, cuanto mayor sea el número total de epochs, menores deberían ser los valores de la weight decay. La SuppFigure 4 muestra que nuestro remedio para este problema, la weight decay normalizada definida en la Ec. (15), simplifica la selección de hiperparámetros porque los valores óptimos observados para ejecuciones cortas son similares a los de ejecuciones mucho más largas. Usamos nuestros experimentos iniciales en CIFAR-10 para sugerir la normalización por raíz cuadrada que propusimos en la Ec. (15) y comprobamos doblemente que esto no es una coincidencia en el dataset ImageNet32x32 (Chrabaszcz et al., 2017), una versión submuestreada del dataset original ImageNet con 1,2 millones de imágenes de 32$ {\textstyle \times} $32 píxeles, donde un epoch es 24 veces más largo que en CIFAR-10. Este experimento también respaldó el escalado por raíz cuadrada: los mejores valores de la weight decay normalizada observados en CIFAR-10 representaban valores casi óptimos para ImageNet32x32 (véase SuppFigure 3). Por el contrario, si hubiéramos usado los mismos valores brutos de weight decay $ {\textstyle \lambda} $ para ImageNet32x32 y CIFAR-10 con el mismo número de epochs, sin la normalización propuesta, $ {\textstyle \lambda} $ habría sido aproximadamente 5 veces demasiado grande para ImageNet32x32, dando lugar a un rendimiento mucho peor. Los valores óptimos de weight decay normalizada también fueron muy similares (p. ej., $ {\textstyle \lambda_{n\hspace{0pt}o\hspace{0pt}r\hspace{0pt}m} = 0.025} $ y $ {\textstyle \lambda_{n\hspace{0pt}o\hspace{0pt}r\hspace{0pt}m} = 0.05} $) entre SGDW y AdamW. Estos resultados muestran claramente que normalizar la weight decay puede mejorar sustancialmente el rendimiento; aunque el escalado por raíz cuadrada funcionó muy bien en nuestros experimentos, subrayamos que estos no fueron muy exhaustivos y que es probable que existan reglas de escalado aún mejores.

    La SuppFigure 4 es el equivalente de la Figura 3 del artículo principal, pero para ImageNet32x32 en lugar de CIFAR-10. Los resultados cualitativos son idénticos: la weight decay produce un mejor training loss (cross-entropy) que la regularización L2, y una mejora aún mayor del error de test.

    La SuppFigure 5 y la SuppFigure 6 son los equivalentes de la Figura 4 del artículo principal, pero complementadas con curvas de training loss en su fila inferior. Los resultados muestran que Adam y sus variantes con weight decay desacoplada convergen más rápido (en términos de training loss) en CIFAR-10 que las correspondientes variantes de SGD (la diferencia para ImageNet32x32 es pequeña). Como se discute en el artículo principal, cuando se consideran los mismos valores de training loss, AdamW muestra mejores valores de error de test que Adam. Curiosamente, la SuppFigure 5 y la SuppFigure 6 muestran que las variantes con restart, AdamWR y SGDWR, también muestran mejor generalización que AdamW y SGDW, respectivamente.

    Refer to caption

    Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption Refer to caption

    Refer to caption    Refer to caption
      
    Refer to caption    Refer to caption

    Refer to caption Refer to caption

    Refer to caption Refer to caption