Nel precedente articolo sull’addestramento di modelli su larga scala, abbiamo esaminato LoRA. In questo articolo esamineremo un’altra strategia adottata da diversi grandi modelli linguistici per una formazione efficiente: Grouped Query Attention (GQA). In breve, Grouped Query Attention (GQA) è una generalizzazione dell’attenzione multi-head (MHA) e dell’attenzione multi-query (MQA), ciascuna delle quali è un caso speciale di GQA. Pertanto, prima di immergerci nell’attenzione alle query raggruppate, rivisitiamo la tradizionale attenzione multitesta proposta da Vaswani et al. nel fondamentale articolo “L’attenzione è tutto ciò di cui hai bisogno”. Successivamente, esploreremo l’attenzione multi-query e il modo in cui affronta le sfide con MHA. Infine, risponderemo alle domande “Che cos’è la GQA?” e “Come ci offre il meglio di entrambi i mondi?”
L’attenzione di più teste è una componente fondamentale dei modelli Transformer, poiché consente loro di elaborare e comprendere in modo efficiente sequenze complesse in attività come la traduzione linguistica, il riepilogo e altro ancora. Per coglierne le complessità, dobbiamo approfondire le basi matematiche e capire come funzionano più teste nel meccanismo di attenzione.
Il meccanismo di attenzione di base calcola una somma ponderata di valori, con pesi che dipendono da una query e da un set di chiavi. Matematicamente, questo è espresso come:
Questo è indicato come attenzione al prodotto punto in scala. In questa equazione, Q (query) e K (key) sono matrici che rappresentano le query e le chiavi. V (Valore) è la matrice dei valori. “d_k” è la dimensionalità delle chiavi, utilizzata per il ridimensionamento.
Espansione con Multi-Head Attention (MHA)
L’attenzione multi-testa impiega più “teste” di livelli di attenzione, consentendo al modello di occuparsi delle informazioni provenienti da diversi sottospazi di rappresentazione. In ciascuna testa è presente un insieme indipendente di livelli lineari (matrici di proiezione) per query, chiave e valori (questo è un punto importante che rivisiteremo in GQA). Per ogni testa (numerata h):
headʰ = Attenzione(Q.WqʰK.WkʰV.Wvʰ)
Concatenazione degli output della testa
Gli output delle singole teste vengono concatenati e poi trasformati linearmente.
Multitesta(Q,K,V) = Concat(testa¹,testa²,…,testaʰ) .Wᵒ
Wᵒ è un’altra matrice di peso che trasforma linearmente il vettore concatenato nella dimensione di output finale.
L’intuizione dietro l’attenzione multi-testa è che applicando il meccanismo dell’attenzione più volte in parallelo, il modello può catturare diversi tipi di relazioni nei dati.
Tuttavia, MHA consente una comprensione sfumata delle relazioni tra le diverse parti dell’input. Tuttavia, questa complessità ha un costo: una richiesta significativa di larghezza di banda della memoria, soprattutto durante l’inferenza del decodificatore.
La sfida della larghezza di banda della memoria nell’attenzione multi-testa
Il nocciolo del problema risiede nel sovraccarico della memoria. Ogni fase di decodifica nei modelli autoregressivi come Transformers richiede il caricamento dei pesi del decodificatore insieme a tutte le chiavi e i valori di attenzione. Questo processo non richiede solo un uso intensivo del calcolo, ma richiede anche un uso intensivo della larghezza di banda della memoria. Con l’aumento delle dimensioni dei modelli, aumenta anche questo sovraccarico, rendendo il dimensionamento un compito sempre più arduo.
L’attenzione multi-query (MQA) è emersa come soluzione per mitigare questo collo di bottiglia. L’idea è semplice ma efficace: utilizzare più query head ma solo una singola chiave e un valore head. Questo approccio riduce significativamente il carico di memoria, migliorando la velocità di inferenza. È stato impiegato in numerosi modelli su larga scala come PaLM, StarCoder e Falcon.
Nell’attenzione multi-query, calcoliamo la media delle intestazioni per chiavi e valori in modo che tutte le intestazioni delle query condividano la stessa chiave e la stessa intestazione del valore. Ciò si ottiene replicando H volte la “testa” del pool medio, dove H è il numero di teste della query.
Fonte: towardsdatascience.com