Il momentum aiuta gli SGD ad affrontare scenari di perdite complessi in modo più efficiente. fotografato da Maxim Berg SU Unsplash.

Se guardi da vicino PyTorch documentazione di SGD, scoprirai che la loro implementazione dello slancio di Nesterov presenta alcune differenze rispetto alla formulazione trovata nel carta originale. In particolare, l’implementazione di PyTorch valuta il gradiente in corrispondenza dei parametri attuali, mentre lo scopo principale del momentum di Nesterov è valutare il gradiente in corrispondenza dei parametri spostati. Sfortunatamente, sembra che la discussione su queste discrepanze su Internet sia scarsa. In questo post esamineremo e spiegheremo le differenze tra l’implementazione di PyTorch e la formulazione originale del momentum di Nesterov. Alla fine, vedremo come l’implementazione di PyTorch non sia sbagliata, ma piuttosto un’approssimazione, e speculeremo sui vantaggi della sua implementazione.

IL carta originale descrive lo slancio di Nesterov utilizzando le seguenti regole di aggiornamento:

Dove v_{t+1} E θ_{t+1} sono rispettivamente il vettore velocità e i parametri del modello al momento T, M è il fattore di slancio e e è il tasso di apprendimento. La nota nell’SGD di PyTorch documentazione afferma di utilizzare le seguenti regole di aggiornamento:

Dove g_{t+1} rappresenta il gradiente utilizzato per il calcolo v_{t+1}. Possiamo espandere la regola di aggiornamento per θ_{t+1} ottenere:

Da ciò possiamo dedurre che:

e le regole di aggiornamento diventano:

Queste sono le regole di aggiornamento che PyTorch utilizza in teoria. Ho menzionato in precedenza che PyTorch valuta effettivamente il gradiente in base ai parametri correnti anziché ai parametri spostati. Questo può essere visto guardando la descrizione dell’algoritmo nella documentazione SGD di PyTorch. Investigheremo ulteriormente questo aspetto in seguito.

Si noti che sia per la formulazione originale (1, 2) che per PyTorch (3, 4), if v_0 = 0quindi il primo aggiornamento a io diventa:

Sebbene la nota di documentazione SGD di PyTorch affermi che l’algoritmo inizializza il momentum buffer sul gradiente nel primo passaggio, mostreremo in seguito che ciò implica v_0 = 0.

Ci sono due differenze immediate quando si passa dalla formulazione originale (1, 2) alla formulazione PyTorch (3, 4):

  1. Il tasso di apprendimento viene spostato all’esterno v_{t+1}.
  2. Nella regola di aggiornamento per v_{t+1}il termine che coinvolge il gradiente viene aggiunto anziché sottratto e nella regola di aggiornamento per θ_{t+1}il termine che coinvolge il vettore velocità viene sottratto anziché aggiunto. La differenza di segno all’interno del termine del gradiente è semplicemente una conseguenza di ciò, come mostrato nella sezione precedente.

Per comprendere queste differenze, espandiamo innanzitutto le regole di aggiornamento. Come accennato Quil’effetto della prima differenza è più evidente se si considerano i ritmi di apprendimento. Quindi, consideriamo una generalizzazione delle regole di aggiornamento dove e non è più fisso ma ora può variare nel tempo e denotare e_t come tasso di apprendimento al passo temporale T. Per brevità, poniamo:

Supponendo v_0 = 0, la formulazione originale diventa:

e la formulazione PyTorch diventa:

Nella formulazione originale (6), se il tasso di apprendimento dovesse cambiare nel tempo Tquindi solo la grandezza del termine a io = t nella somma verrebbero influenzati e le grandezze di tutti gli altri termini rimarrebbero le stesse. Di conseguenza, l’influenza immediata della variazione del tasso di apprendimento è piuttosto limitata e dovremmo attendere che la variazione del tasso di apprendimento “si riversi” nei successivi passaggi temporali per avere un’influenza maggiore sulla dimensione complessiva del passaggio. Al contrario, nella formulazione PyTorch (7), se il tasso di apprendimento dovesse cambiare nel tempo Tallora l’entità dell’intero passo verrebbe influenzata immediatamente.

Per v_0 = 0dalle regole ampliate risulta chiaramente che la seconda differenza in definitiva non ha alcun effetto; in entrambe le formulazioni, il passaggio corrisponde a una somma scontata di gradienti che viene sottratta dai parametri correnti.

Ignorando il decadimento e lo smorzamento del peso, analizzando l’algoritmo SGD in PyTorch documentazionepossiamo vedere che le regole di aggiornamento implementate sono:

Dove θ’_{t+1} sono i parametri del modello in quel momento T E

Faremo riferimento alle equazioni 3 e 4 come formulazione “nota” di PyTorch e alle equazioni 8 e 9 come formulazione “implementata” di PyTorch. Facciamo una distinzione tra io E io’ per un motivo che diventerà evidente presto. La differenza più evidente rispetto alla formulazione delle note è che il gradiente viene valutato in base ai parametri correnti anziché ai parametri spostati. Già da questo potrebbe sembrare che le regole di aggiornamento implementate dall’algoritmo non siano un’implementazione corretta dello slancio di Nesterov.

Esamineremo ora come l’algoritmo PyTorch approssima in ultima analisi lo slancio di Nesterov. È possibile trovare derivazioni per una versione precedente di PyTorch Qui da Ivo Danihelka, citato in questo problema di GitHub. È possibile trovare derivazioni per la versione corrente di PyTorch Quiche è un aggiustamento relativamente semplice rispetto alle derivazioni precedenti. Forniamo qui una resa LaTeX di queste derivazioni (ri-derivate) per comodità del lettore. La formulazione implementata deriva da un semplice cambio di variabili. Nello specifico lasciamo:

Diventa subito chiaro che la regola di aggiornamento delle note per v_{t+1} (3) diventa equivalente alla regola di aggiornamento implementata per v_{t+1} (8) dopo il cambio di variabili. Vogliamo ora derivare una regola di aggiornamento per θ’_{t+1} in termini di θ’_t:

Questa è esattamente la regola di aggiornamento che abbiamo visto implementata in PyTorch (9). Ad alto livello, l’implementazione PyTorch assume i parametri correnti θ’_t sono già la versione spostata dei parametri “effettivi”. θ_t. Quindi, ad ogni passo temporale, i parametri “effettivi”. θ_t sono legati ai parametri attuali θ’_t di:

Tuttavia, dal codice sorgente risulta che l’implementazione SGD di PyTorch non apporta alcuna correzione alla fine dell’algoritmo per recuperare i parametri “effettivi” finali, quindi l’output finale è tecnicamente un’approssimazione dei parametri “effettivi”.

Infine, ora lo mostriamo v_0 deve essere 0:

Inoltre, possiamo confermare che il primo aggiornamento dei parametri “effettivi” è lo stesso primo aggiornamento effettuato nella formulazione originaria quando v_0 = 0:

Possiamo vedere che questo è equivalente all’equazione 5.

Naturalmente, la grande domanda rimasta è: perché PyTorch si preoccupa di riformulare lo slancio di Nesterov dalle equazioni 3 e 4 alle equazioni 8 e 9? Una possibile spiegazione è che la riformulazione potrebbe consentire un certo risparmio nel numero di operazioni aritmetiche richieste. Per valutare questa possibile spiegazione, contiamo il numero di operazioni aritmetiche. Per la formulazione della nota (3, 4), abbiamo:

Qui ci sono un totale di sette operazioni. Per la formulazione implementata (8, 9), abbiamo:

Qui ci sono un totale di sei operazioni. Il secondo gradiente nell’implementazione PyTorch utilizza semplicemente il risultato salvato dal primo calcolo del gradiente, quindi viene eseguito solo un calcolo del gradiente per ogni fase temporale. Quindi, un vantaggio evidente è che l’implementazione di PyTorch riduce un’operazione di moltiplicazione aggiuntiva ad ogni passaggio.

In sintesi:

  1. Le regole di aggiornamento indicate nella nota sulla documentazione SGD di PyTorch (3, 4) hanno una posizione diversa per il tasso di apprendimento rispetto alle regole di aggiornamento del momentum di Nesterov originali (1, 2). Ciò consente alle pianificazioni del tasso di apprendimento di avere un effetto immediato sulla dimensione complessiva del passo, mentre la formulazione originale avrebbe l’effetto di far sì che le modifiche del tasso di apprendimento “si riversino” sulle fasi temporali successive.
  2. Le regole di aggiornamento implementate nell’algoritmo PyTorch SGD (8, 9) sono un’approssimazione alle regole di aggiornamento indicate nella nota di documentazione (3, 4) dopo un semplice cambio di variabili. Sebbene i parametri “effettivi” siano facilmente recuperabili dai parametri correnti in ogni fase temporale, l’implementazione PyTorch non apporta alcuna correzione di questo tipo alla fine dell’algoritmo, quindi i parametri finali rimangono tecnicamente un’approssimazione dei parametri finali “effettivi” .
  3. Un evidente vantaggio dell’implementazione PyTorch è che evita un’ulteriore operazione di moltiplicazione in ogni fase temporale.
  1. “SGD.” SGD: documentazione di PyTorch 2.0, pytorch.org/docs/stable/generated/torch.optim.SGD.html. Accesso effettuato il 2 settembre 2023.
  2. Sutskever, Ilya et al. “Sull’importanza dell’inizializzazione e dello slancio nel deep learning.“Conferenza internazionale sull’apprendimento automatico. PMLR, 2013.
  3. Danielka, Ivo. “Lo slancio di Nesterov reso semplice.“25 agosto 2012.
  4. Chintala, Soumith. “Lo slancio di nesterov è sbagliato in sgd · Numero 27 · torcia/ottim.” GitHub, 13 ottobre 2014, github.com/torch/optim/issues/27.
  5. Lordo, Sam. “Aggiungi una nota nei documenti sulla formulazione dello slancio utilizzata in optim · Numero 1099 · pytorch/pytorch.” GitHub, 25 marzo 2017, github.com/pytorch/pytorch/issues/1099#issuecomment-289190614.
  6. Zhao, Yilong. “correggi il bug di Nesterov Momentum · Problema n. 5920 · pytorch/pytorch.” GitHub, 21 marzo 2018, https://github.com/pytorch/pytorch/pull/5920#issuecomment-375181908.

Lascia un commento

Il tuo indirizzo email non sarà pubblicato. I campi obbligatori sono contrassegnati *