L'attenzione flash è un meccanismo di attenzione del trasformatore di ottimizzazione della potenza che fornisce un'efficienza del 15%.
Flash Attention è un meccanismo di attenzione del trasformatore di ottimizzazione della potenza che fornisce un'efficienza del 15% in termini di velocità dell'orologio da parete senza approssimazione.
Dato che i modelli del trasformatore sono lenti e affamati di memoria su sequenze lunghe (la complessità del tempo e della memoria è di natura quadratica), l'attenzione flash (carta) fornisce un'accelerazione end-to-end del 15% su BERT-large e una velocità 3x su GPT-2.
Considerando l'enorme quantità di energia consumata nell'addestramento di questi modelli di grandi dimensioni, l'attenzione di Flash con l'ottimizzazione del software e dell'hardware è in grado di fornire un'efficienza del 15%, il che rappresenta un enorme vantaggio in termini di miglioramento.
Di seguito, la discussione aiuta a spiegare alcuni dei concetti di base alla base dell'attenzione flash e come viene implementata.
Concetti di base su calcolo e memoria
Prima di approfondire l'analisi del calcolo e della memoria, rivediamoli:
Cos'è il calcolo?
- Tempo impiegato sulla GPU per il calcolo delle operazioni effettive in virgola mobile (FLOPS)
Cos'è la memoria?
- Tempo impiegato per trasferire i tensori all'interno di una GPU
Idealmente, vogliamo che la nostra gCPU esegua continuamente la moltiplicazione di matrici e non sia limitata dalla memoria. Ma in realtà, l’elaborazione ha fatto più progressi rispetto alla memoria e ci troviamo in un mondo in cui la gCPU rimane inattiva in attesa che i dati vengano caricati. Questo di solito viene chiamato legato alla memoria operazione. Fare riferimento di seguito al diagramma illustrativo che illustra questo. La moltiplicazione di matrici è considerata calcolo e la memoria memorizza i dati (considerandoli come un magazzino). L'elaborazione necessita di dati da elaborare e la larghezza di banda della memoria deve supportare tale operazione.
Cos'è la gerarchia della memoria?
La GPU A100 ha 40-80 GB di memoria a larghezza di banda elevata con una larghezza di banda di 1,5–2,0 TB/s E 192KB di SRAM su chip con ciascuno 108 multiprocessori di streaming con larghezza di banda stimata intorno 19 TB/s.
Tenendo presente il contesto di cui sopra, l'architettura dell'attenzione al sé è legato alla memoria.
Osservando i calcoli dell'attenzione, si tratta di un'operazione softmax che causa il limite alla memoria.
- Prova quantitativa: come puoi vedere di seguito, operazioni come softmax, dropout e mascheramento richiedono la maggior parte del tempo rispetto alla moltiplicazione di matrici (Matmul)
Perché softmax diventa un'operazione legata alla memoria?
La scala su cui opera è il nostro più grande collo di bottiglia. Nel diagramma sottostante
- N -> numero di token
- d -> numero di dimensioni di incorporamento
- Quando Query e Key' vengono moltiplicati, la matrice dell'attenzione esplode in N * N, il che richiede molta memoria. Per riferimento (d ~128; N ~128.000 token; google gemini: ~1 milione di token)
Di seguito è riportato l'algoritmo per implementare il meccanismo di autoattenzione
Come notato nella sezione precedente, trasferire le informazioni su HBM (scrivere S su HBM) e quindi ricaricare da HBM a gCPU per calcolare softmax e quindi riscrivere su HBM sono molte informazioni che viaggiano rendendolo operazione legata alla memoria.
Insieme al diagramma, i passaggi seguenti aiutano a spiegare come viene calcolata l'attenzione verso se stessi attraverso la moltiplicazione di matrici
Passo 1:
- L'ho semplificato. In pratica, ogni token viene aggiunto con codifica posizionale per generare incorporamenti da inserire in uno strato lineare da generare
. A scopo illustrativo ho utilizzato la dimensione 3 (generalmente è compresa tra 64 e 128). Questo è l'ingresso dell'architettura del trasformatore standard.
Passo 2
- Chiave -> Chiave' (trasposizione) viene calcolata e moltiplicata con Query per fornire QK' che è N*N. Questo contiene l'attenzione di ciascun token con il resto dei token. Il diagramma seguente mostra anche la relazione. Poiché si tratta di token e dobbiamo calcolare l'importanza di ciascun token rispetto all'altro, l'operazione softmax viene applicata in base alle righe per normalizzarla da 0 a 1.
- Questo passaggio richiede lo spostamento alla HBM ed è l'operazione più costosa come abbiamo discusso. L'intero documento di attenzione flash spiega come ottimizzare questo processo.
Passaggio 3
- Softmax(QK') * V viene calcolato come matrice di output finale. La dimensione qui è la stessa degli incorporamenti di input di chiave, query e valore.
- Riga finale nella matrice di output
- 1*5 significa che l'incorporamento di “questo” dovrebbe essere modificato per incorporare le relazioni con altri token.
- 2*5 significa che l'incorporamento di “is” dovrebbe essere modificato per incorporare le relazioni con altri token.
- Come sopra per il resto delle altre righe
L'idea di base è spiegata attraverso il diagramma sottostante in cui i blocchi di chiave, query e valore vengono propagati da HBM a SRAM e attraverso alcuni trucchi matematici (spiegati di seguito), il calcolo eseguito qui non è una risposta approssimativa ma effettivamente corretta.
Con questa implementazione, la carta è in grado di ridurre il tempo wall-speed accedendo alle informazioni in blocchi senza sacrificare la correttezza.
Algoritmo alla base del documento: come viene implementata l'attenzione di Flash?
Questa è la parte più complessa del documento. Suddividiamo questo problema in sottoaspetti e approfondiamo.
Il diagramma seguente suddivide la matrice in blocchi e spiega come ciascun blocco viene utilizzato per calcolare il softmax parziale e quindi correggere il softmax.
- Input iniziale: Token: questo è un documento di attenzione flash
- Chiave: 4 (token) X 3 (dimensioni), Query: 4 (token) X 3 (dimensioni) e Valore: 4 (token) X 3 (dimensioni)
Passaggio 0
- Supponiamo che la memoria sia di 24 byte
- La SRAM sarà divisa in 4 blocchi (Query, Chiave, Valore e matrice di output)
- Query, chiave, valore e output riceveranno = 6 byte ciascuno per memorizzare le proprie informazioni (12 byte/4)
- Ogni dimensione è 3 poiché ogni incorporamento non può essere interrotto, quindi
- Query: 6 byte/ 3 (dimensione) = 2. Lo stesso per valore, chiave e output
- Quindi, (M/4d) dà la dimensione di ciascun blocco. In questo caso, la dimensione del blocco è 2. Ciò significa che è possibile caricare 2 righe nella SRAM.
- In senso generale, la dimensione del blocco è (M/4d) e il numero di blocchi è (N*4D/M)
Passaggio 1 e 2: aggiunta di una tabella di seguito che illustra i passaggi 1 e 2 su come funziona l'attenzione flash e confronta gli aspetti di memoria e calcolo.
Il diagramma seguente aiuta a visualizzare la moltiplicazione della matrice (blocco per blocco) utilizzata nell'attenzione flash.
Qual è l'aspetto matematico di softmax?
Uno degli aspetti più critici del documento riguarda il modo in cui la scomposizione delle matrici si traduce ancora nel calcolo della precisione softmax. Lasciando l'esempio matematico di seguito su come mostrare due matrici diverse, è possibile bastonare per calcolare nuovamente il softmax.
Intuizione
- Questa è la bellissima proprietà degli esponenti che qui viene sfruttata.
- Ogni softmax viene calcolato individualmente ma insieme a questo valore massimo della riga viene memorizzato insieme al valore dell'esponente sommato.
- Quando ci uniamo con un'altra matrice, dobbiamo verificare quanto il massimo differisce dal massimo globale di 2 matrici. E a causa dell'esponente, sia il numeratore che il denominatore vengono regolati con e^(current_max — global_max) per incorporarlo.
La logica è piuttosto complessa e quindi lasciamo un esempio qui sotto da seguire. Una volta familiarizzato con un esempio, l’intuizione di cui sopra avrà molto senso.
Diamo un'occhiata all'analisi della complessità per avere un'idea di come sono cambiate le cose
Attenzione a sé stessi
- Mentre si calcola S = QK' diventa una matrice N*N che deve essere propagata nuovamente all'HRAM e quindi ritirata dall'HRAM.
- Quindi O(N*N + N*N) = O(N*N) è l'accesso HBM
Attenzione lampo
- Ciclo esterno: si accederà a chiave e query O(Nd) volte
- Anello interno: Sarà necessario solo O(Nd/M) per caricare dalla HBM poiché si opera su blocchi
- Complessivamente: O(N*N*d*d/M)
- In pratica, d è molto più piccolo di M. d varia da (64 a 128) mentre M varia da 100 KB e quindi l'accesso HBM è ottimizzato
- Abbiamo iniziato con l'obiettivo di ottimizzare l'accesso HBM e con questa analisi della complessità vediamo che il documento ha ottimizzato l'accesso Accesso HBM tramite fattore (d*d/M) senza approssimazione.
Un documento così complesso con un enorme miglioramento in termini di efficienza. Spero che la spiegazione di cui sopra fornisca qualche intuizione su come l'attenzione del flash ottimizza e migliora le prestazioni. Non ho trattato il blocco dell'attenzione flash sparsa, come si confronta con altre tecniche di ottimizzazione, ottimizzazione dei passaggi in avanti ecc. Spero di trattarlo in un post futuro.
Riferimenti
Fonte: towardsdatascience.com