Vettorializza e parallelizza ambienti RL con JAX: Q-learning alla velocità della luce⚡ |  di Ryan Pegoud |  Ottobre 2023

 | Intelligenza-Artificiale

JAX lo è ancora un altro Framework Python Deep Learning sviluppato da Google e ampiamente utilizzato da aziende come DeepMind.

“JAX lo è Autograd (differenziazione automatica) e XL (Accelerated Linear Algebra, un compilatore TensorFlow), riuniti per il calcolo numerico ad alte prestazioni. — Documentazione ufficiale

A differenza di ciò a cui è abituata la maggior parte degli sviluppatori Python, JAX non abbraccia l’ programmazione orientata agli oggetti (OOP), ma piuttosto programmazione funzionale (FP)(1).

In parole povere, si basa su funzioni pure (deterministico E senza effetti collaterali) E strutture dati immutabili (invece di modificare i dati in atto, nuove strutture dati Sono creato con le modifiche desiderate) come elementi costitutivi primari. Di conseguenza, FP incoraggia un approccio più funzionale e matematico alla programmazione, rendendolo adatto per attività come il calcolo numerico e l’apprendimento automatico.

Illustriamo le differenze tra questi due paradigmi esaminando lo pseudocodice per una funzione Q-update:

  • IL orientato agli oggetti l’approccio si basa su a istanza di classe contenente vari variabili di stato (come i valori Q). La funzione update è definita come un metodo di classe that aggiorna il stato interno dell’istanza.
  • IL programmazione funzionale l’approccio si basa su a pura funzione. In effetti, questo aggiornamento Q lo è deterministico poiché i valori Q vengono passati come argomento. Pertanto, qualsiasi chiamata a questa funzione con il file stessi input risulterà nel stesse uscite mentre gli output di un metodo di classe possono dipendere dallo stato interno dell’istanza. Anche, strutture dati come lo sono gli array definito E modificata nel portata globale.
Implementazione di un aggiornamento Q in Programmazione orientata agli oggetti E Programmazione Funzionale (fatto dall’autore)

In quanto tale, JAX offre una varietà di decoratori di funzioni che sono particolarmente utili nel contesto di RL:

  • vmap (mappa vettorizzata): Permette di applicare una funzione che agisce su un singolo campione su a lotto. Ad esempio, se env.passo() è una funzione che esegue un passo in un singolo ambiente, vmap(env.step)() è una funzione che esegue un passaggio ambienti multipli. In altre parole, vmap aggiunge un file dimensione del lotto ad una funzione.
Illustrazione di a fare un passo funzione vettorizzata utilizzando vmap (fatto dall’autore)
  • JEsso (compilazione just-in-time): consente a JAX di eseguire un “Compilazione Just In Time di una funzione JAX Python” realizzandolo Compatibile con XLA. In sostanza, l’uso di jit ci consente di farlo compilare funzioni e fornisce miglioramenti significativi della velocità (in cambio di qualche sovraccarico aggiuntivo durante la prima compilazione della funzione).
  • pmap (mappa parallela): Analogamente a vmap, pmap consente una facile parallelizzazione. Tuttavia, invece di aggiungere una dimensione batch a una funzione, replica la funzione e la esegue diversi dispositivi XLA. Nota: quando si applica pmap, viene applicato anche jit automaticamente.
Illustrazione di a fare un passo funzione parallelizzata utilizzando pmap (fatto dall’autore)

Ora che abbiamo gettato le basi di JAX, vedremo come ottenere enormi accelerazioni vettorizzando gli ambienti.

Innanzitutto, cos’è un ambiente vettorizzato e quali problemi risolve la vettorizzazione?

Nella maggior parte dei casi, gli esperimenti RL lo sono rallentato di Trasferimenti dati CPU-GPU. Algoritmi RL di Deep Learning come Ottimizzazione della politica prossimale (PPO) utilizzano le reti neurali per approssimare la politica.

Come sempre nel Deep Learning, le reti neurali utilizzano GPU A formazione E inferenza tempo. Tuttavia, nella maggior parte dei casi, ambienti correre su processore (anche nel caso di più ambienti utilizzati in parallelo).

Ciò significa che il consueto ciclo RL di selezione delle azioni tramite la politica (reti neurali) e di ricezione di osservazioni e ricompense dall’ambiente richiede continui avanti e indietro tra la GPU e la CPU, che danneggia le prestazioni.

Inoltre, utilizzando framework come PyTorch senza “jitting” potrebbe causare un sovraccarico, poiché la GPU potrebbe dover attendere che Python invii osservazioni e ricompense dalla CPU.

Consueta configurazione dell’addestramento in batch RL in PyTorch (fatto dall’autore)

D’altra parte, JAX ci consente di eseguire facilmente ambienti batch sulla GPU, eliminando l’attrito causato dal trasferimento dei dati GPU-CPU.

Inoltre, poiché jit compila il nostro codice JAX in XLA, l’esecuzione non è più (o almeno meno) influenzata dall’inefficienza di Python.

Configurazione dell’addestramento in batch RL in JAX (fatto dall’autore)

Per maggiori dettagli e interessanti applicazioni a ricerca RL sul metaapprendimentoConsiglio vivamente questo post sul blog di Chris Lu.

Diamo un’occhiata all’implementazione delle diverse parti del nostro esperimento RL. Ecco una panoramica di alto livello delle funzioni di base di cui avremo bisogno:

Metodi di classe richiesti per una semplice configurazione RL (realizzati dall’autore)

L’ambiente

Questa implementazione segue lo schema previsto da Nikolaj Goodger nel suo fantastico articolo sugli ambienti di scrittura in JAX.

Cominciamo con a visione di alto livello dell’ambiente e dei suoi metodi. Questo è un piano generale per implementare un ambiente in JAX:

Diamo uno sguardo più da vicino ai metodi della classe (come promemoria, le funzioni che iniziano con “_” sono privato e non deve essere chiamato al di fuori dell’ambito del corso):

  • _get_obs: questo metodo converte lo stato dell’ambiente in un’osservazione per l’agente. In un parzialmente osservabile O Stocastico ambientale, le funzioni di trattamento applicate allo Stato andrebbero qui.
  • _Ripristina: Poiché eseguiremo più agenti in parallelo, abbiamo bisogno di un metodo per i ripristini individuali al completamento di un episodio.
  • _reimposta_se_fatto: questo metodo verrà chiamato ad ogni passaggio e attiverà _reset se il flag “fatto” è impostato su True.
  • Ripristina: Questo metodo viene chiamato all’inizio dell’esperimento per ottenere lo stato iniziale di ciascun agente, nonché le chiavi casuali associate
  • fare un passo: Dato uno stato e un’azione, l’ambiente restituisce un’osservazione (nuovo stato), una ricompensa e il flag “fatto” aggiornato.

In pratica, un’implementazione generica di un ambiente GridWorld sarebbe simile a questa:

Si noti che, come accennato in precedenza, tutti i metodi della classe seguono il metodo programmazione funzionale paradigma. In effetti, non aggiorniamo mai lo stato interno dell’istanza della classe. Inoltre, il attributi di classe sono tutti costanti che non verrà modificato dopo l’istanziazione.

Diamo uno sguardo più da vicino:

  • __dentro__: Nel contesto del nostro GridWorld, le azioni disponibili sono (0, 1, 2, 3). Queste azioni vengono tradotte in un array bidimensionale utilizzando movimenti.di.sé e aggiunto allo stato nella funzione passo.
  • _get_obs: Il nostro ambiente lo è deterministico E pienamente osservabilepertanto l’agente riceve direttamente lo stato anziché un’osservazione elaborata.
  • _reimposta_se_fatto: L’argomento env_state corrisponde alla tupla (stato, chiave) dove chiave è a jax.random.PRNGKey. Questa funzione restituisce semplicemente lo stato iniziale se il file Fatto flag è impostato su True, tuttavia, non possiamo utilizzare il flusso di controllo Python convenzionale all’interno delle funzioni jitted JAX. Utilizzando jax.lax.cond essenzialmente otteniamo un’espressione equivalente a:
def cond(condition, true_fun, false_fun, operand):
if condition: # if done flag == True
return true_fun(operand) # return self._reset(key)
else:
return false_fun(operand) # return env_state
  • fare un passo: Convertiamo l’azione in un movimento e la aggiungiamo allo stato corrente (jax.numpy.clip garantisce che l’agente rimanga all’interno della rete). Aggiorniamo quindi il env_state tuple prima di verificare se è necessario reimpostare l’ambiente. Poiché la funzione step viene utilizzata frequentemente durante l’allenamento, il jitting consente notevoli miglioramenti delle prestazioni. IL @partial(jit, static_argnums=(0, ) decoratore segnala che il “se stesso” dovrebbe essere considerato l’argomento del metodo della classe statico. In altre parole, il le proprietà della classe sono costanti e non cambierà durante le chiamate successive alla funzione step.

Agente Q-Learning

L’agente Q-learning è definito da aggiornamento funzione, nonché una funzione statica tasso di apprendimento E fattore di sconto.

Ancora una volta, quando eseguiamo la funzione di aggiornamento, passiamo l’argomento “self” come statico. Inoltre, si noti che il q_values la matrice viene modificata sul posto utilizzando impostato() e il suo valore non è memorizzato come attributo di classe.

Politica Epsilon-Greedy

Infine, la politica utilizzata in questo esperimento è lo standard politica avida di epsilon. Un dettaglio importante è che utilizza tie-break casualiil che significa che se il valore Q massimo non è unico, l’azione lo sarà campionato uniformemente dal valori Q massimi (l’uso di argmax restituirebbe sempre la prima azione con il valore Q massimo). Ciò è particolarmente importante se i valori Q vengono inizializzati come una matrice di zeri, poiché l’azione 0 (sposta a destra) verrebbe sempre selezionata.

Altrimenti, la politica può essere riassunta da questo frammento:

action = lax.cond(
explore, # if p < epsilon
_random_action_fn, # select a random action given the key
_greedy_action_fn, # select the greedy action w.r.t Q-values
operand=subkey, # use subkey as an argument for the above funcs
)
return action, subkey

Tieni presente che quando usiamo a chiave in JAX (ad esempio qui abbiamo campionato un float casuale e utilizzato random.choice) è pratica comune dividere la chiave in seguito (ovvero “passare a un nuovo stato casuale”, maggiori dettagli Qui).

Ora che disponiamo di tutti i componenti richiesti, addestriamo un singolo agente.

Ecco un Divinatorio ciclo di training, come puoi vedere stiamo essenzialmente selezionando un’azione utilizzando la policy, eseguendo un passaggio nell’ambiente e aggiornando i valori Q, fino alla fine di un episodio. Quindi ripetiamo il processo per N Episodi. Come vedremo tra poco, questo modo di formare un agente è abbastanza inefficientetuttavia, riassume i passaggi chiave dell’algoritmo in modo leggibile:

Su una singola CPU completiamo 10.000 episodi in 11 secondi, al ritmo di 881 episodi e 21.680 passi al secondo.

100%|██████████| 10000/10000 (00:11<00:00, 881.86it/s)
Total Number of steps: 238 488
Number of steps per second: 21 680

Ora replichiamo lo stesso ciclo di training utilizzando la sintassi JAX. Ecco una descrizione di alto livello di srotolare funzione:

Fonte: towardsdatascience.com

Lascia un commento

Il tuo indirizzo email non sarà pubblicato. I campi obbligatori sono contrassegnati *