Una tecnica di messa a punto unificata più economica e veloce

Immagine generata con DALL-E 3 dall'autore

ORPO è un nuova entusiasmante tecnica di messa a punto che combina le tradizionali fasi di fine tuning supervisionato e di allineamento delle preferenze in un unico processo. Ciò riduce le risorse computazionali e il tempo necessari per la formazione. Inoltre, i risultati empirici dimostrano che ORPO supera gli altri metodi di allineamento su varie dimensioni e parametri di riferimento del modello.

In questo articolo metteremo a punto il nuovo modello Llama 3 8B utilizzando ORPO con la libreria TRL. Il codice è disponibile su Google Co e nel Corso LLM su GitHub.

L'ottimizzazione delle istruzioni e l'allineamento delle preferenze sono tecniche essenziali per adattare i Large Language Models (LLM) a compiti specifici. Tradizionalmente, ciò comporta un processo in più fasi: 1/ Messa a punto supervisionata (SFT) sulle istruzioni per adattare il modello al dominio target, seguito da 2/ metodi di allineamento delle preferenze come l'apprendimento per rinforzo con feedback umano (RLHF) o l'ottimizzazione delle preferenze dirette (DPO) per aumentare la probabilità di generare risposte preferite rispetto a quelle rifiutate.

Immagine dell'autore

Tuttavia, i ricercatori hanno identificato una limitazione in questo approccio. Sebbene SFT adatti efficacemente il modello al dominio desiderato, lo fa inavvertitamente aumenta la probabilità di generare risposte indesiderate accanto a quelli preferiti. Questo è il motivo per cui la fase di allineamento delle preferenze è necessaria per ampliare il divario tra le probabilità dei risultati preferiti e di quelli rifiutati.

Si noti come la probabilità di risposte rifiutate aumenta durante la messa a punto supervisionata (immagine dal documento ORPO).

Presentato da Hong e Lee (2024)ORPO offre una soluzione elegante a questo problema combinando l'ottimizzazione delle istruzioni e l'allineamento delle preferenze in un unico processo di formazione monolitico. ORPO modifica l'obiettivo standard della modellazione del linguaggio, combinando la perdita di probabilità logaritmica negativa con un termine OR (odds ratio). Questa perdita di OR penalizza debolmente le risposte rifiutate mentre premia fortemente quelle preferite, consentendo al modello di apprendere contemporaneamente il compito target e di allinearsi con le preferenze umane.

ORPO è stato implementato nelle principali librerie di fine-tuning, come TRL, AxolotlE LLaMA-Fabbrica. Nella prossima sezione vedremo come utilizzare con TRL.

Lama 3 è l'ultima famiglia di LLM sviluppata da Meta. I modelli sono stati addestrati su un ampio set di dati di 15 trilioni di gettoni (rispetto ai gettoni 2T per Lama 2). Sono state rilasciate due dimensioni del modello: un modello da 70 miliardi di parametri e un modello più piccolo da 8 miliardi di parametri. Il modello 70B ha già dimostrato prestazioni impressionanti, ottenendo un punteggio di 82 nel benchmark MMLU e 81,7 nel benchmark HumanEval.

I modelli Llama 3 hanno inoltre aumentato la lunghezza del contesto fino a 8.192 token (4.096 token per Llama 2) e potenzialmente scalano fino a 32k con RoPE. Inoltre, i modelli utilizzano un nuovo tokenizzatore con un vocabolario di 128.000 token, riducendo il numero di token richiesti per codificare il testo del 15%. Questo vocabolario spiega anche il passaggio dai parametri 7B a 8B.

Campioni da ORPO-DPO-mix-40k (immagine dell'autore).

ORPO richiede un set di dati sulle preferenze, incluso un prompt, una risposta scelta e una risposta rifiutata. In questo esempio useremo mlabonne/orpo-dpo-mix-40kuna combinazione dei seguenti set di dati DPO di alta qualità:

Grazie a argilla, disallineamento, M4-aiE jondurbin per fornire i set di dati di origine.

Come al solito, iniziamo installando le librerie richieste:

pip install -U transformers datasets accelerate peft trl bitsandbytes wandb

Una volta installato, possiamo importare le librerie necessarie e accedere a W&B (opzionale):

import gc
import os

import torch
import wandb
from datasets import load_dataset
from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format
wb_token = userdata.get('wandb')
wandb.login(key=wb_token)

Se hai una GPU recente, dovresti essere in grado di utilizzare anche il file Libreria Flash Attenzione per sostituire l'implementazione predefinita dell'attenzione desiderosa con una più efficiente.

if torch.cuda.get_device_capability()(0) >= 8:
!pip install -qqq flash-attn
attn_implementation = "flash_attention_2"
torch_dtype = torch.bfloat16
else:
attn_implementation = "eager"
torch_dtype = torch.float16

Di seguito caricheremo il modello Llama 3 8B con precisione a 4 bit grazie a bitsandbytes. Impostiamo quindi la configurazione LoRA utilizzando PEFT per QLoRA. Sto anche usando il conveniente setup_chat_format() funzione per modificare il modello e il tokenizzatore per ChatML supporto. Applica automaticamente questo modello di chat, aggiunge token speciali e ridimensiona il livello di incorporamento del modello per adattarlo alla nuova dimensione del vocabolario.

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"

# QLoRA config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
)

# LoRA config
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=('up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj')
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Load model
model = AutoModelForCausalLM.from_pretrained(
base_model,
quantization_config=bnb_config,
device_map="auto",
attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

Ora che il modello è pronto per l'addestramento, possiamo occuparci del set di dati. Carichiamo mlabonne/orpo-dpo-mix-40k e utilizzare il apply_chat_template() funzione per convertire le colonne “scelto” e “rifiutato” nel formato ChatML. Tieni presente che sto utilizzando solo 1.000 campioni e non l'intero set di dati, poiché l'esecuzione richiederebbe troppo tempo.

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(10))

def format_chat_template(row):
row("chosen") = tokenizer.apply_chat_template(row("chosen"), tokenize=False)
row("rejected") = tokenizer.apply_chat_template(row("rejected"), tokenize=False)
return row

dataset = dataset.map(
format_chat_template,
num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

Innanzitutto, dobbiamo impostare alcuni iperparametri:

  • learning_rate: ORPO utilizza tassi di apprendimento molto bassi rispetto alla tradizionale SFT o addirittura al DPO. Questo valore di 8e-6 deriva dal documento originale e corrisponde approssimativamente a un tasso di apprendimento SFT di 1e-5 e a un tasso di apprendimento DPO di 5e-6. Consiglierei di aumentarlo intorno a 1e-6 per una vera messa a punto.
  • beta: È il parametro $\lambda$ presente nel documento, con valore predefinito pari a 0,1. Un'appendice dell'articolo originale mostra come è stato selezionato con uno studio di ablazione.
  • Altri parametri, come max_length e le dimensioni del batch sono impostate per utilizzare tutta la VRAM disponibile (~20 GB in questa configurazione). Idealmente, addestreremmo il modello per 3-5 epoche, ma qui ci limiteremo a 1.

Infine, possiamo addestrare il modello utilizzando ORPOTrainer, che funge da wrapper.

orpo_args = ORPOConfig(
learning_rate=8e-6,
beta=0.1,
lr_scheduler_type="linear",
max_length=1024,
max_prompt_length=512,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=4,
optim="paged_adamw_8bit",
num_train_epochs=1,
evaluation_strategy="steps",
eval_steps=0.2,
logging_steps=1,
warmup_steps=10,
report_to="wandb",
output_dir="./results/",
)

trainer = ORPOTrainer(
model=model,
args=orpo_args,
train_dataset=dataset("train"),
eval_dataset=dataset("test"),
peft_config=peft_config,
tokenizer=tokenizer,
)

trainer.train()
trainer.save_model(new_model)

L'addestramento del modello su questi 1.000 campioni ha richiesto circa 2 ore su una GPU L4. Controlliamo i grafici W&B:

Anche se la perdita diminuisce, la differenza tra le risposte scelte e quelle rifiutate non è chiara: il margine medio e la precisione sono solo leggermente superiori a zero e 0,5 rispettivamente.

Nell'articolo originale, gli autori hanno addestrato modelli sul Anthropic/hh-rlhf set di dati (161.000 campioni) per 10 epoche, che è molto più lungo della nostra esecuzione rapida. Hanno anche sperimentato Llama 3 e gentilmente condiviso i loro registri con me (grazie Jiwoo Hong).

Per terminare questo tutorial, uniamo l'adattatore QLoRA con il modello base e spingiamolo su Hugging Face Hub.

# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()

# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
base_model,
low_cpu_mem_usage=True,
return_dict=True,
torch_dtype=torch.float16,
device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)

# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()
model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)

Congratulazioni, abbiamo terminato questa rapida messa a punto di Llama 3: mlabonne/OrpoLlama-3–8B. Puoi giocarci usando questo Abbracciare lo spazio del viso (ecco un taccuino per crearne uno tuo). Sebbene il modello sia sottodimensionato, come evidenziato dalle curve W&B, ho eseguito alcune valutazioni sulla suite di benchmark di Nous utilizzando Valutazione automatica LLM.

La nostra messa a punto ORPO è in realtà abbastanza decente e migliora le prestazioni del modello base su ogni benchmark. Ciò è incoraggiante e probabilmente significa che una messa a punto di tutti i 40.000 campioni produrrebbe ottimi risultati.

Questo è un momento entusiasmante per la comunità open source, con il rilascio di sempre più modelli open-weight di alta qualità. Il divario tra i modelli closed source e open-weight si sta lentamente riducendo e la messa a punto è uno strumento essenziale per ottenere le migliori prestazioni per i tuoi casi d'uso.

Immagine dell'autore

In questo articolo, abbiamo introdotto l'algoritmo ORPO e spiegato come unifica le fasi SFT e di allineamento delle preferenze in un unico processo. Quindi, abbiamo utilizzato TRL per mettere a punto un modello Llama 3 8B su un set di dati delle preferenze personalizzate. Il modello finale mostra risultati incoraggianti ed evidenzia il potenziale di ORPO come nuovo paradigma di messa a punto.

Spero che sia stato utile e consiglio di eseguire il file Taccuino di Colab per mettere a punto i tuoi modelli Llama 3. Nei prossimi articoli vedremo come creare set di dati di alta qualità, un punto che spesso viene trascurato. Se ti è piaciuto questo articolo, seguimi su Volto che abbraccia e Twitter @maximelabonne.

Fonte: towardsdatascience.com

Lascia un commento

Il tuo indirizzo email non sarà pubblicato. I campi obbligatori sono contrassegnati *