Una delicata introduzione all’apprendimento per rinforzo profondo in JAX |  di Ryan Pegoud |  Novembre 2023

 | Intelligenza-Artificiale

Implementare la funzione di aggiornamento per DQN è leggermente più complesso, analizziamolo:

  • Prima il _loss_fn La funzione implementa l’errore quadrato descritto in precedenza per a singola esperienza.
  • Poi, _batch_loss_fn funge da involucro per _loss_fn e lo decora con vmapapplicando la funzione di perdita ad a lotto di esperienze. Restituiamo quindi l’errore medio per questo batch.
  • Finalmente, update agisce come strato finale della nostra funzione di perdita, calcolandola pendenza rispetto ai parametri di rete online, ai parametri di rete target e a una serie di esperienze. Usiamo quindi Optax (una libreria JAX comunemente utilizzata per l’ottimizzazione) per eseguire un passaggio di ottimizzazione e aggiornare i parametri online.

Si noti che, analogamente al buffer di riproduzione, il modello e l’ottimizzatore lo sono funzioni pure modificando un stato esterno. La riga seguente serve come un buon esempio di questo principio:

updates, optimizer_state = optimizer.update(grads, optimizer_state)

Questo spiega anche perché possiamo utilizzare un unico modello sia per la rete online che per quella target, poiché i parametri vengono memorizzati e aggiornati esternamente.

# target network predictions
self.model.apply(target_net_params, None, state)
# online network predictions
self.model.apply(online_net_params, None, state)

Per contesto, il modello che utilizziamo in questo articolo è a percettrone multistrato definito come segue:

N_ACTIONS = 2
NEURONS_PER_LAYER = (64, 64, 64, N_ACTIONS)
online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)

@hk.transform
def model(x):
# simple multi-layer perceptron
mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)
return mlp(x)

online_net_params = model.init(online_key, jnp.zeros((STATE_SHAPE,)))
target_net_params = model.init(target_key, jnp.zeros((STATE_SHAPE,)))

prediction = model.apply(online_net_params, None, state)

Buffer di riproduzione

Ora facciamo un passo indietro e guardiamo più da vicino i buffer di replay. Sono ampiamente utilizzati nell’apprendimento per rinforzo per una serie di motivi:

  • Generalizzazione: Campionando dal buffer di replay, interrompiamo la correlazione tra esperienze consecutive mescolandone l’ordine. In questo modo, evitiamo un adattamento eccessivo a sequenze specifiche di esperienze.
  • Diversità: poiché il campionamento non è limitato alle esperienze recenti, generalmente osserviamo una varianza inferiore negli aggiornamenti ed evitiamo un adattamento eccessivo alle esperienze più recenti.
  • Maggiore efficienza del campione: Ogni esperienza può essere campionata più volte dal buffer, consentendo al modello di apprendere di più dalle singole esperienze.

Infine, possiamo utilizzare diversi schemi di campionamento per il nostro buffer di replay:

  • Campionamento uniforme: Le esperienze vengono campionate in modo uniforme e casuale. Questo tipo di campionamento è semplice da implementare e consente al modello di apprendere dalle esperienze indipendentemente dal momento in cui sono state raccolte.
  • Campionamento prioritario: Questa categoria include diversi algoritmi come Riproduzione dell’esperienza prioritaria (“PER”, Schaul et al. 2015) O Riproduzione dell’esperienza sfumata (“GER”, Lahire et al., 2022). Questi metodi tentano di dare priorità alla selezione delle esperienze secondo alcuni parametri legati alla loro “potenziale di apprendimento” (l’ampiezza dell’errore TD per PER e la norma del gradiente dell’esperienza per GER).

Per motivi di semplicità, in questo articolo implementeremo un buffer di replay uniforme. Tuttavia, ho intenzione di coprire ampiamente il campionamento prioritario in futuro.

Come promesso, il buffer di replay uniforme è abbastanza semplice da implementare, tuttavia esistono alcune complessità legate all’uso di JAX e alla programmazione funzionale. Come sempre, dobbiamo lavorare funzioni pure che sono privo di effetti collaterali. In altre parole, non ci è consentito definire il buffer come un’istanza di classe con uno stato interno variabile.

Invece, inizializziamo a buffer_state dizionario che mappa le chiavi su array vuoti con forme predefinite, poiché JAX richiede array di dimensioni costanti durante la compilazione jit del codice su XLA.

buffer_state = {
"states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
"actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
"rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
"next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
"dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}

Useremo a UniformReplayBuffer classe per interagire con lo stato del buffer. Questa classe ha due metodi:

  • add: apre una tupla di esperienza e associa i suoi componenti a un indice specifico. idx = idx % self.buffer_size assicura che quando il buffer è pieno, l’aggiunta di nuove esperienze sovrascrive quelle più vecchie.
  • sample: Campiona una sequenza di indici casuali dalla distribuzione casuale uniforme. La lunghezza della sequenza è impostata da batch_size mentre il range degli indici è (0, current_buffer_size-1). Ciò garantisce di non campionare array vuoti mentre il buffer non è ancora pieno. Infine, utilizziamo JAX vmap in combinazione con tree_map per restituire un lotto di esperienze.

Ora che il nostro agente DQN è pronto per la formazione, implementeremo rapidamente un ambiente CartPole vettorizzato utilizzando lo stesso framework introdotto in un articolo precedente. CartPole è un ambiente di controllo con a ampio spazio di osservazione continua, il che rende rilevante testare il nostro DQN.

Rappresentazione visiva dell’ambiente CartPole (crediti e documentazione: Palestra OpenAIlicenza MIT)

Il processo è abbastanza semplice, ne riutilizziamo la maggior parte Implementazione del Gymnasium di OpenAI assicurandoci di utilizzare array JAX e flusso di controllo lassista invece di alternative Python o Numpy, ad esempio:

# Python implementation
force = self.force_mag if action == 1 else -self.force_mag
# Jax implementation
force = lax.select(jnp.all(action) == 1, self.force_mag, -self.force_mag) )

# Python
costheta, sintheta = math.cos(theta), math.sin(theta)
# Jax
cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)

# Python
if not terminated:
reward = 1.0
...
else:
reward = 0.0
# Jax
reward = jnp.float32(jnp.invert(done))

Per brevità, il codice completo dell’ambiente è disponibile qui:

L’ultima parte della nostra implementazione di DQN è il ciclo di formazione (chiamato anche roll-out). Come accennato negli articoli precedenti, dobbiamo rispettare un formato specifico per sfruttare la velocità di JAX.

La funzione di implementazione potrebbe sembrare inizialmente scoraggiante, ma la maggior parte della sua complessità è puramente sintattica poiché abbiamo già trattato la maggior parte degli elementi costitutivi. Ecco una procedura dettagliata sullo pseudo-codice:

1. Initialization:
* Create empty arrays that will store the states, actions, rewards
and done flags for each timestep. Initialize the networks and optimizer
with dummy arrays.
* Wrap all the initialized objects in a val tuple

2. Training loop (repeat for i steps):
* Unpack the val tuple
* (Optional) Decay epsilon using a decay function
* Take an action depending on the state and model parameters
* Perform an environment step and observe the next state, reward
and done flag
* Create an experience tuple (state, action, reward, new_state, done)
and add it to the replay buffer
* Sample a batch of experiences depending on the current buffer size
(i.e. sample only from experiences that have non-zero values)
* Update the model parameters using experience batch
* Every N steps, update the target network's weights
(set target_params = online_params)
* Store the experience's values for the current episode and return
the updated `val` tuple

Ora possiamo eseguire DQN per 20.000 passi e osservare le esibizioni. Dopo circa 45 episodi, l’agente riesce ad ottenere prestazioni discrete, bilanciando il palo per più di 100 passi consecutivi.

IL barre verdi indicano che l’agente è riuscito a bilanciare il palo più di 200 passi, risolvere l’ambiente. In particolare, l’agente ha stabilito il suo record sul 51esimo episodiocon 393 passi.

Rapporto sulle prestazioni per DQN (realizzato dall’autore)

IL 20.000 passi di allenamento furono giustiziati in poco più di un secondoad un tasso di 15.807 passi al secondo (su a singola CPU)!

Queste prestazioni suggeriscono le impressionanti capacità di scalabilità di JAX, consentendo ai professionisti di eseguire esperimenti parallelizzati su larga scala con requisiti hardware minimi.

Running for 20,000 iterations: 100%|██████████| 20000/20000 (00:01<00:00, 15807.81it/s)

Daremo un’occhiata più da vicino procedure di rollout parallelizzate correre statisticamente significante esperimenti e ricerche di iperparametri in un prossimo articolo!

Nel frattempo, sentiti libero di riprodurre l’esperimento e dilettarti con gli iperparametri utilizzando questo taccuino:

Come sempre, grazie per aver letto fin qui! Spero che questo articolo abbia fornito una buona introduzione a Deep RL in JAX. Se hai domande o feedback relativi al contenuto di questo articolo, assicurati di farcelo sapere, sono sempre felice di fare una piccola chiacchierata 😉

Alla prossima volta 👋

Fonte: towardsdatascience.com

Lascia un commento

Il tuo indirizzo email non sarà pubblicato. I campi obbligatori sono contrassegnati *