Sfide nella generazione di arresti in Llama 2 |  di Shuyang Xiang |  Settembre 2023

 | Intelligenza-Artificiale

Un’esplorazione con potenziali soluzioni

Lama: foto di Liudmila Shuvalova

Il lancio di Llama 2 da parte di Meta ha acceso l’entusiasmo all’interno della comunità, segnando l’alba di un’era per modelli linguistici di grandi dimensioni ben eseguiti che in precedenza erano accessibili solo tramite API specifiche dell’azienda.

Tuttavia, è importante riconoscere alcune imperfezioni inerenti a questi modelli. Tra questi spicca in modo prominente il problema della generazione degli stop. Le mie esperienze personali hanno dimostrato che questi modelli spesso faticano a determinare il punto di “stop” appropriato, lasciandoli incerti su quando terminare la generazione di un testo.

In questo post del blog, approfondirò il problema dei fallimenti nella generazione degli arresti nel modello Llama 2 più piccolo, il modello Llama 2–7b, e discuterò diversi potenziali rimedi. L’implementazione nelle prossime sezioni può essere trovata in questo GoogleGolab taccuino con il tipo di runtime T4.

In questa sezione, sfrutteremo la potenza di un modello Llama 2–7b utilizzando una GPU T4 dotata di ampie risorse RAM elevate in Google Colab (2,21 crediti/ora). È essenziale tenere presente che la GPU T4 ha una capacità VRAM di 16 GB, sufficiente per ospitare il peso di Llama 2–7b (7b × 2 byte = 14 GB in FP16).

Per gestire in modo efficiente l’utilizzo della VRAM, utilizzeremo una tecnica chiamata quantizzazione. La quantizzazione è un approccio che si concentra sulla minimizzazione dei requisiti sia computazionali che di memoria durante l’inferenza rappresentando pesi e attivazioni utilizzando tipi di dati a bassa precisione.

Esaminiamo ora il seguente frammento di codice. Qui, dimostreremo come caricare il modello “meta-llama/Llama-2–7b-chat-hf” con una configurazione Bite e Byte e impostare una pipeline di generazione di testo basata su questo modello caricato.

# bit and byte configuration
bnb_config = transformers.BitsAndBytesConfig(
load_in_4bit=True, # 4-bit quantization
bnb_4bit_quant_type='nf4', # Normalized float 4
bnb_4bit_use_double_quant=True, # Second quantization after the first
bnb_4bit_compute_dtype=bfloat16 # Computation type
)
model_id= "meta-llama/Llama-2-7b-chat-hf"
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# Llama 2 Tokenizer
tokenizer =transformers.AutoTokenizer.from_pretrained(model_id)

# Llama 2 Model
model = transformers.AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
quantization_config=bnb_config,
device_map='auto',
)
model.eval()

# Difine pipeline
transformers.pipeline(model=model,
tokenizer=tokenizer,
return_full_text=True, # langchain expects the full text
task='text-generation',
# we pass model parameters here too
max_new_tokens=500, # max number of tokens to generate in the output
temperature=0.1,
)

Questa configurazione ci consente di generare testo per un determinato prompt in meno di un minuto. Mettiamolo alla prova con una domanda semplice: “Cosa posso fare a Parigi?” Di seguito troverai la risposta (tieni presente che i risultati potrebbero variare a causa delle impostazioni della temperatura).

Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. everybody's dream destination. Here are some of the top things to do in Paris:

1. Visit iconic landmarks: The Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, Arc de Triomphe, and the Champs-Élysées are must-visit attractions.

2. Explore museums and art galleries: The Louvre, Musée d'Orsay, and Centre Pompidou are some of the world's most famous museums.

3. Take a river cruise: A Seine River cruise is a great way to see the city's landmarks and bridges while learning about its history and culture.

4. Discover the city's neighborhoods: From the fashionable Champs-Élysées to the charming Montmartre, each Parisian neighborhood has its own unique character and charm.

# I skipped some text here because of the limt of space#

Of course, there are countless

È evidente che il modello fatica a produrre una risposta soddisfacente; sembra avere difficoltà a sapere quando concludere la sua produzione. Dopo aver tokenizzato il testo generato, diventa evidente che il token finale non è un 2, che rappresenta il token eos (fine sequenza) nel tokenizer del modello.

Dopo un esame più attento dei punteggi token (probabilità) forniti dal modello, ho notato che the token_id 2 (eso_token_id) ha un punteggio pari a “-inf”. Ciò implica che non ha alcuna possibilità di essere generato.

In questa sezione esploreremo diverse potenziali soluzioni volte ad affrontare il problema in questione. È essenziale tenere presente che le soluzioni qui discusse rappresentano sforzi proattivi, ma potrebbero non sempre fornire soluzioni ai problemi in questione.

Processore Logit

Un modello linguistico come Llama 2 elabora una sequenza di token di testo come input e produce una sequenza di probabilità condizionali per il token successivo, in base al contesto dal token iniziale a quello corrente. Alla luce di ciò, vale la pena considerare aggiustamenti manuali a queste probabilità man mano che ci avviciniamo al limite massimo dei token, con l’obiettivo di aumentare la probabilità di incontrare il token eos. Lo facciamo definendo il nostro LogitsProcessor personalizzato chiamato “EosTokenRewardLogitsProcessor” con due input iniziali eos_token_id e max_length dove quest’ultimo rappresenta la lunghezza massima alla quale il modello dovrebbe generare un token eos:

class EosTokenRewardLogitsProcessor(LogitsProcessor):
def __init__(self, eos_token_id: int, max_length: int):

if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

if not isinstance(max_length, int) or max_length < 1:
raise ValueError(f"`max_length` has to be a integer bigger than 1, but is {max_length}")

self.eos_token_id = eos_token_id
self.max_length=max_length

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape(-1)
# start to increese the reward of the eos_tokekn from 80% max length progressively on length
for cur_len in (max(0,int(self.max_length*0.8)), self.max_length ):
ratio = cur_len/self.max_length
num_tokens = scores.shape(1) # size of vocab
scores(:, (i for i in range(num_tokens) if i != self.eos_token_id)) =\
scores(:, (i for i in range(num_tokens) if i != self.eos_token_id))*ratio*10*torch.exp(-torch.sign(scores(:, (i for i in range(num_tokens) if i != self.eos_token_id))))
scores(:, self.eos_token_id) = 1e2*ratio
return scores

Nel metodo “__call__” della classe aumentiamo la probabilità (punteggio) dell’eos_token in base alla lunghezza della sequenza. Quando la lunghezza si avvicina all’80% della lunghezza massima specificata, impostiamo il punteggio di eos_token_id su 1e2 moltiplicato per un rapporto di lunghezza e adeguiamo di conseguenza i punteggi degli altri token verso il basso.

Ora dichiara il processore logit nella definizione della pipeline:

pipe = transformers.pipeline(model=model,
tokenizer=tokenizer,
return_full_text=True, # langchain expects the full text
task='text-generation',
# we pass model parameters here too
#stopping_criteria=stopping_criteria, # without this model rambles during chat
logits_processor=logits_process_list,
max_new_tokens=500, # max number of tokens to generate in the output
temperature=0.1,
)

Eseguiamo nuovamente la pipeline con lo stesso prompt “Cosa posso fare a Parigi” e otteniamo:

Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere.

Funziona bene! Abbiamo una risposta completa anche se potrebbe sembrare breve.

Ritocchi

Se il modello non riesce a generare il token EOS, perché non considerare di dargli istruzioni in tal senso? Il concetto di migliorare le prestazioni del modello perfezionandolo con un set di dati che includa le risposte che si concludono con il token EOS è sicuramente una strada promettente da esplorare.

In questa sezione, utilizzerò spudoratamente le basi esposte in questo post del blog che hanno utilizzato un metodo PEFT (parametro-efficient fine-tuning), come QLoRA, per mettere a punto il modello Llama 2–7b. Proprio come il suo predecessore, LoRA, QLoRA utilizza un piccolo set di parametri addestrabili (adattatori) mantenendo invariati i parametri principali del modello. Introduce due innovazioni degne di nota: NormalFloat (NF4) a 4 bit, un metodo di quantizzazione dei dati teoricamente ottimale per i dati normali e la doppia quantizzazione. Per una comprensione più approfondita si prega di consultare il carta originalese dovessi avere ulteriore interesse per questo argomento.

Addestriamo il modello su un set di dati chiamato “timdettmers/openassistant-guanaco” dove puoi trovare il database dei volti abbracciati. Questo set di dati ha il seguente formato in cui la conversazione tra la persona e l’assistente è separata da “###”.

Autore dell’immagine: “timdettmers/openassistant-guanaco’/ dataset

Prima dell’addestramento, dobbiamo trasformare i dati nel modello di prompt di Llama 2:

<s>(INST) <<SYS>>
{your_system_message}
<</SYS>> {user_message_1} (/INST)

Tralascerò qui i dettagli della trasformazione del set di dati. Vediamo ora la parte principale della formazione data dal seguente codice:

# Load LoRA configuration
peft_config = LoraConfig(
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
r=lora_r,
bias="none",
task_type="CAUSAL_LM",
)

# Set supervised fine-tuning parameters
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_arguments,
packing=packing,
)

# Train model
trainer.train()

Nel contesto di un set di dati comprendente istruzioni e risposte, il nostro approccio prevedeva l’uso di un formatore supervisionato (SFTainer) in combinazione con il metodo QLoRA per ottimizzare i parametri di peso all’interno del modello linguistico (LLM). Il nostro obiettivo principale era ridurre al minimo le discrepanze tra le risposte generate e le risposte basate sulla verità, che fungevano da etichette di riferimento.

Un parametro significativo all’interno di questa configurazione è “lora r”, che rappresenta un valore relativamente piccolo pertinente sia alla seconda che alla prima dimensione delle coppie di matrici dei pesi di decomposizione dei ranghi. L’addestramento è avvenuto esclusivamente su queste due matrici, integrando i pesi esistenti.

Addestriamo il modello per 250 passaggi con la perdita di addestramento indicata nel grafico seguente:

Immagine dell’autore: perdita di allenamento del lama 2 per 250 passi

Ora eseguiamo la pipeline con il modello ottimizzato. Questa volta otteniamo:

 Paris, the capital of France, is known for its stunning architecture, rich history, and cultural attractions. Here are some popular things to do in Paris:

1. Visit iconic landmarks: The Eiffel Tower, Arc de Triomphe, Notre-Dame Cathedral, and the Louvre Museum are some of the most famous landmarks in Paris.
2. Explore museums and art galleries: Paris is home to many world-class museums and art galleries, including the Louvre, the Musée d'Orsay, and the Pompidou Center.
3. Enjoy the city's gardens and parks: Paris has several beautiful gardens and parks, including the Luxembourg Gardens, the Tuileries Garden, and the Champ de Mars.
4. Take a river cruise: A river cruise along the Seine is a great way to see the city's sights and learn about its history.
5. Shop on the Champs-Élysées: The famous avenue is lined with high-end boutiques, department stores, and cafes.
6. Attend a cabaret show: The Moulin Rouge is a famous cabaret that offers a lively and entertaining show.
7. Visit Montmartre: The charming neighborhood of Montmartre is known for its artistic vibe, street performers, and stunning views of the city.
8. Take a food tour: Paris is known for its culinary delights, and a food tour is a great way to sample the city's famous cuisine.
9. Visit the Palace of Versailles: The Palace of Versailles is a beautiful royal palace located just outside of Paris.
10. Attend a fashion show: Paris is the fashion capital of the world, and attending a fashion show is a great way to see the latest trends and styles.

These are just a few of the many things to do in Paris. The city has something for everyone, from history and culture to food and entertainment.

Questa è davvero una bella risposta!

Bonus: dai un suggerimento diverso

In definitiva, abbiamo a disposizione un trucco semplice ma efficace, soprattutto quando la verbosità del modello non è un problema. Possiamo menzionare esplicitamente nel prompt che richiediamo una risposta concisa. Ad esempio, quando chiedo alla modella: “Cosa posso fare a Parigi? Rispondi in cinque frasi o meno”, prevede:

Sure! Here are five things you can do in Paris: 1. Visit the Eiffel Tower, a iconic landmark that offers stunning views of the city. 2. Explore the Louvre Museum and see famous artworks like the Mona Lisa. 3. Stroll along the Seine River and enjoy the city's picturesque bridges and charming streets. 4. Visit the Notre-Dame Cathedral, a beautiful and historic church. 5. Enjoy a delicious French meal at a traditional bistro or restaurant.

È una risposta breve ma pulita e completa.

Criteri di arresto: un tentativo fallito

Per coloro che sono interessati, Hugging Face ha introdotto un’altra API chiamata StoppingCriteria, intesa a stabilire condizioni specifiche che costringono una sequenza a fermarsi. Tuttavia, quando si tratta di definire un criterio personalizzato che interrompa il modello quando incontra determinati token (ad esempio, “\n”), potrebbe non fornire una soluzione completa al problema. Ad esempio, ho tentato di creare una classe StopOnTokens:

# define custom stopping criteria object
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stop_ids in stop_token_ids:
if torch.eq(input_ids(0)(-len(stop_ids):), stop_ids).all():
return True
return False

stopping_criteria = StoppingCriteriaList((StopOnTokens()))

Tuttavia, il modello non riesce ancora a fornire una risposta completa.

In questo post del blog ho evidenziato il problema dell’arresto della generazione in Llama 2 e ho introdotto diverse soluzioni temporanee. Ancora una volta, tralascio molti dettagli delle implementazioni e ti consiglio di dare uno sguardo più approfondito al mio taccuino.

Immagine di José Aragones

Tuttavia, è importante notare che queste soluzioni hanno lo scopo di migliorare la facilità d’uso delle risposte a breve termine, ma stiamo aspettando con impazienza una soluzione permanente per risolvere questo problema.

Fonte: towardsdatascience.com

Lascia un commento

Il tuo indirizzo email non sarà pubblicato. I campi obbligatori sono contrassegnati *