Gli ingegneri di DeepMind accelerano la nostra ricerca costruendo strumenti, ampliando algoritmi e creando mondi virtuali e fisici stimolanti per l’addestramento e il test dei sistemi di intelligenza artificiale (AI). Come parte di questo lavoro, valutiamo costantemente nuove librerie e framework di machine learning.
Recentemente, abbiamo scoperto che un numero crescente di progetti è ben servito da JAXun framework di machine learning sviluppato da Ricerca Google squadre. JAX si adatta bene alla nostra filosofia ingegneristica ed è stato ampiamente adottato dalla nostra comunità di ricerca nell’ultimo anno. Qui condividiamo la nostra esperienza di lavoro con JAX, spieghiamo perché lo troviamo utile per la nostra ricerca sull’intelligenza artificiale e forniamo una panoramica dell’ecosistema che stiamo costruendo per supportare i ricercatori ovunque.
Perché JAX?
JAX è una libreria Python progettata per il calcolo numerico ad alte prestazioni, in particolare per la ricerca sull’apprendimento automatico. La sua API per le funzioni numeriche è basata su NumPyun insieme di funzioni utilizzate nel calcolo scientifico. Sia Python che NumPy sono ampiamente utilizzati e familiari, rendendo JAX semplice, flessibile e facile da adottare.
Oltre alla sua API NumPy, JAX include un sistema estensibile di trasformazioni di funzioni componibili che aiutano a supportare la ricerca sul machine learning, tra cui:
- Differenziazione: L’ottimizzazione basata sui gradienti è fondamentale per il machine learning. JAX supporta nativamente sia la modalità avanti che quella inversa differenziazione automatica di funzioni numeriche arbitrarie, tramite trasformazioni di funzioni come grad, hessian, jacfwd e jacrev.
- Vettorializzazione: Nella ricerca ML spesso applichiamo una singola funzione a molti dati, ad esempio calcolando la perdita in un lotto o valutazione dei gradienti per esempio per l’apprendimento differenzialmente privato. JAX fornisce la vettorizzazione automatica tramite la trasformazione vmap che semplifica questa forma di programmazione. Ad esempio, i ricercatori non devono ragionare sul batching quando implementano nuovi algoritmi. JAX supporta anche il parallelismo dei dati su larga scala tramite la relativa trasformazione pmap, distribuendo elegantemente i dati troppo grandi per la memoria di un singolo acceleratore.
- Compilazione JIT: XL viene utilizzato per compilare ed eseguire programmi JAX just-in-time (JIT) su GPU e TPU nuvola acceleratori. La compilazione JIT, insieme all’API coerente NumPy di JAX, consente ai ricercatori senza esperienza precedente nel calcolo ad alte prestazioni di scalare facilmente su uno o più acceleratori.
Abbiamo scoperto che JAX ha consentito una rapida sperimentazione con nuovi algoritmi e architetture e ora è alla base di molte delle nostre recenti pubblicazioni. Per saperne di più, prendi in considerazione la possibilità di partecipare alla nostra tavola rotonda JAX, mercoledì 9 dicembre alle 19:00 GMT, presso il NeurIPS conferenza virtuale.
JAX presso DeepMind
Sostenere la ricerca all’avanguardia sull’intelligenza artificiale significa bilanciare la prototipazione rapida e l’iterazione rapida con la capacità di implementare esperimenti su una scala tradizionalmente associata ai sistemi di produzione. Ciò che rende questo tipo di progetti particolarmente impegnativi è che il panorama della ricerca si evolve rapidamente ed è difficile da prevedere. In qualsiasi momento, una nuova svolta nella ricerca può cambiare, e succede regolarmente, la traiettoria e le esigenze di interi team. In questo panorama in continua evoluzione, una responsabilità fondamentale del nostro team di ingegneri è garantire che le lezioni apprese e il codice scritto per un progetto di ricerca vengano riutilizzati in modo efficace in quello successivo.
Un approccio che si è dimostrato efficace è la modularizzazione: estraiamo gli elementi costitutivi più importanti e critici sviluppati in ciascun progetto di ricerca in elementi ben testati ed efficienti componenti. Ciò consente ai ricercatori di concentrarsi sulla propria ricerca beneficiando al tempo stesso del riutilizzo del codice, delle correzioni di bug e dei miglioramenti delle prestazioni negli ingredienti algoritmici implementati dalle nostre librerie principali. Abbiamo anche scoperto che è importante assicurarsi che ciascuna libreria abbia un ambito chiaramente definito e garantire che siano interoperabili ma indipendenti. Buy-in incrementalela capacità di scegliere le funzionalità senza essere vincolati ad altre, è fondamentale per fornire la massima flessibilità ai ricercatori e supportarli sempre nella scelta dello strumento giusto per il lavoro.
Altre considerazioni che hanno riguardato lo sviluppo del nostro ecosistema JAX includono la garanzia che rimanga coerente (ove possibile) con la progettazione del nostro ecosistema esistente TensorFlow biblioteche (es Sonetto E TRFL). Abbiamo anche mirato a costruire componenti che (dove rilevante) corrispondano il più fedelmente possibile alla matematica sottostante, per essere auto-descrittivi e ridurre al minimo i salti mentali “dalla carta al codice”. Alla fine, abbiamo scelto di farlo fonte aperta le nostre biblioteche per facilitare la condivisione dei risultati della ricerca e per incoraggiare la comunità più ampia a esplorare l’ecosistema JAX.
Il nostro ecosistema oggi
Haiku
Il modello di programmazione JAX delle trasformazioni di funzioni componibili può rendere complicato il trattamento di oggetti con stato, ad esempio reti neurali con parametri addestrabili. Haiku è una libreria di reti neurali che consente agli utenti di utilizzare modelli di programmazione orientati agli oggetti familiari sfruttando al tempo stesso la potenza e la semplicità del paradigma funzionale puro di JAX.
L’Haiku è utilizzato attivamente da centinaia di ricercatori su DeepMind e Google e ha già trovato adozione in diversi progetti esterni (ad es Coassiale, DeepChem, NumPyro). Si basa sull’API per Sonettoil nostro modello di programmazione basato su moduli per le reti neurali in TensorFlow e abbiamo mirato a rendere il porting da Sonnet a Haiku il più semplice possibile.
Optax
L’ottimizzazione basata sui gradienti è fondamentale per il machine learning. Optax fornisce una libreria di trasformazioni di gradiente, insieme ad operatori di composizione (ad esempio catena) che consentono di implementare molti ottimizzatori standard (ad esempio RMSProp o Adam) in una sola riga di codice.
La natura compositiva di Optax supporta naturalmente la ricombinazione degli stessi ingredienti di base in ottimizzatori personalizzati. Offre inoltre una serie di utilità per la stima del gradiente stocastico e l’ottimizzazione del secondo ordine.
Molti utenti Optax hanno adottato Haiku ma, in linea con la nostra filosofia di buy-in incrementale, è supportata qualsiasi libreria che rappresenti parametri come strutture ad albero JAX (ad es. Elegia, Lino E Stax). Perfavore guarda Qui per ulteriori informazioni su questo ricco ecosistema di librerie JAX.
Salmone
Molti dei nostri progetti di maggior successo si trovano all’intersezione tra deep learning e apprendimento per rinforzo (RL), noto anche come apprendimento per rinforzo profondo. RLax è una libreria che fornisce elementi costitutivi utili per la costruzione di agenti RL.
I componenti di RLax coprono un ampio spettro di algoritmi e idee: apprendimento TD, gradienti politici, critici degli attori, MAP, ottimizzazione prossimale delle politiche, trasformazione del valore non lineare, funzioni di valore generali e una serie di metodi di esplorazione.
Anche se alcuni introduttivi agenti di esempio vengono forniti, RLax non è inteso come un framework per la creazione e la distribuzione di sistemi di agenti RL completi. Un esempio di framework di agenti completo che si basa su componenti RLax è Acme.
Chex
I test sono fondamentali per l’affidabilità del software e il codice di ricerca non fa eccezione. Per trarre conclusioni scientifiche da esperimenti di ricerca è necessario avere fiducia nella correttezza del codice. Chex è una raccolta di utilità di test utilizzate dagli autori di librerie per verificare che gli elementi costitutivi comuni siano corretti e robusti e dagli utenti finali per verificare il loro codice sperimentale.
Chex fornisce un assortimento di utilità tra cui test unitari compatibili con JAX, asserzioni di proprietà di tipi di dati JAX, simulazioni e falsificazioni e ambienti di test multi-dispositivo. Chex viene utilizzato in tutto l’ecosistema JAX di DeepMind e da progetti esterni come Coassiale E MineRL.
Jrap
Rappresentare graficamente le reti neurali (GNN) rappresentano un’entusiasmante area di ricerca con molte applicazioni promettenti. Vedi, ad esempio, il nostro recente lavoro su previsione del traffico in Google Maps e il nostro lavoro su simulazione fisica. Jraph (pronunciato “giraffa”) è una libreria leggera per supportare il lavoro con GNN in JAX.
Jraph fornisce una struttura dati standardizzata per i grafici, un insieme di utilità per lavorare con i grafici e uno “zoo” di modelli di reti neurali a grafo facilmente divisibili ed estensibili. Altre caratteristiche chiave includono: batch di GraphTuple che sfruttano in modo efficiente gli acceleratori hardware, supporto per la compilazione JIT di grafici a forma variabile tramite riempimento e mascheramento e perdite definite sulle partizioni di input. Come Optax e le nostre altre librerie, Jraph non pone vincoli alla scelta dell’utente di una libreria di rete neurale.
Scopri di più sull’utilizzo della libreria dalla nostra ricca raccolta di esempi.
Il nostro ecosistema JAX è in continua evoluzione e incoraggiamo la comunità di ricerca ML a esplorarlo le nostre biblioteche e il potenziale di JAX per accelerare la propria ricerca.