Estrazione delle relazioni con i modelli Llama3 |  di Silvia Onofrei |  Aprile 2024

 | Intelligenza-Artificiale

Estrazione migliorata delle relazioni ottimizzando Llama3–8B con un set di dati sintetici creato utilizzando Llama3–70B

Generato con DALL-E.

L'estrazione delle relazioni (RE) è il compito di estrarre le relazioni dal testo non strutturato per identificare le connessioni tra varie entità denominate. Viene eseguito insieme al riconoscimento dell'entità denominata (NER) ed è un passaggio essenziale in una pipeline di elaborazione del linguaggio naturale. Con l'avvento dei Large Language Models (LLM), gli approcci tradizionali supervisionati che comportano l'etichettatura degli intervalli di entità e la classificazione delle relazioni (se presenti) tra di essi vengono migliorati o completamente sostituiti da approcci basati su LLM (1).

Llama3 è la versione principale più recente nel dominio di GenerativeAI (2). Il modello base è disponibile in due dimensioni, 8B e 70B, con un modello 400B previsto a breve. Questi modelli sono disponibili sulla piattaforma HuggingFace; Vedere (3) per dettagli. La variante 70B alimenta il nuovo sito di chat di Meta Meta.ai e mostra prestazioni paragonabili a ChatGPT. Il modello 8B è tra i più performanti della sua categoria. L'architettura di Llama3 è simile a quella di Llama2, con l'aumento delle prestazioni dovuto principalmente all'aggiornamento dei dati. Il modello viene fornito con un tokenizzatore aggiornato e una finestra di contesto estesa. È etichettato come open source, sebbene venga rilasciata solo una piccola percentuale dei dati. Nel complesso è un modello eccellente e non vedo l'ora di provarlo.

Llama3–70B può produrre risultati sorprendenti, ma a causa delle sue dimensioni è poco pratico, proibitivamente costoso e difficile da usare sui sistemi locali. Pertanto, per sfruttare le sue capacità, abbiamo chiesto a Llama3–70B di insegnare al più piccolo Llama3–8B il compito di estrarre le relazioni dal testo non strutturato.

Nello specifico, con l'aiuto di Llama3–70B, costruiamo un set di dati supervisionato per l'ottimizzazione mirata all'estrazione delle relazioni. Utilizziamo quindi questo set di dati per ottimizzare Llama3–8B per migliorare le sue capacità di estrazione delle relazioni.

Per riprodurre il codice nel file Taccuino di Google Colab associato a questo blog, avrai bisogno di:

  • Credenziali HuggingFace (per salvare il modello messo a punto, opzionale) e accesso a Llama3, ottenibile seguendo le istruzioni di una delle schede dei modelli;
  • Un libero GroqCloud account (puoi accedere con un account Google) e una chiave API corrispondente.

Per questo progetto ho utilizzato un Google Colab Pro dotato di GPU A100 e impostazione High-RAM.

Iniziamo installando tutte le librerie richieste:

!pip install -q groq
!pip install -U accelerate bitsandbytes datasets evaluate
!pip install -U peft transformers trl

Mi ha fatto molto piacere notare che l'intera configurazione ha funzionato dall'inizio senza problemi di dipendenze o necessità di installazione transformers dalla fonte, nonostante la novità del modello.

Dobbiamo anche consentire l'accesso a Goggle Colab all'unità e ai file e impostare la directory di lavoro:

# For Google Colab settings
from google.colab import userdata, drive

# This will prompt for authorization
drive.mount('/content/drive')

# Set the working directory
%cd '/content/drive/MyDrive/postedBlogs/llama3RE'

Per coloro che desiderano caricare il modello su HuggingFace Hub, dobbiamo caricare le credenziali dell'Hub. Nel mio caso, questi sono archiviati nei segreti di Google Colab, a cui è possibile accedere tramite il pulsante chiave a sinistra. Questo passaggio è facoltativo.

# For Hugging Face Hub setting
from huggingface_hub import login

# Upload the HuggingFace token (should have WRITE access) from Colab secrets
HF = userdata.get('HF')

# This is needed to upload the model to HuggingFace
login(token=HF,add_to_git_credential=True)

Ho anche aggiunto alcune variabili di percorso per semplificare l'accesso ai file:

# Create a path variable for the data folder
data_path = '/content/drive/MyDrive/postedBlogs/llama3RE/datas/'

# Full fine-tuning dataset
sft_dataset_file = f'{data_path}sft_train_data.json'

# Data collected from the the mini-test
mini_data_path = f'{data_path}mini_data.json'

# Test data containing all three outputs
all_tests_data = f'{data_path}all_tests.json'

# The adjusted training dataset
train_data_path = f'{data_path}sft_train_data.json'

# Create a path variable for the SFT model to be saved locally
sft_model_path = '/content/drive/MyDrive/llama3RE/Llama3_RE/'

Ora che il nostro spazio di lavoro è configurato, possiamo passare al primo passaggio, ovvero creare un set di dati sintetico per l'attività di estrazione delle relazioni.

Sono disponibili diversi set di dati per l'estrazione delle relazioni, il più noto è il CoNLL04 set di dati. Inoltre, ci sono set di dati eccellenti come web_nlgdisponibile su HuggingFace e SciREX sviluppato da AllenAI. Tuttavia, la maggior parte di questi set di dati prevede licenze restrittive.

Ispirato al formato del web_nlg set di dati costruiremo il nostro set di dati. Questo approccio sarà particolarmente utile se intendiamo mettere a punto un modello addestrato sul nostro set di dati. Per iniziare, abbiamo bisogno di una raccolta di frasi brevi per il nostro compito di estrazione delle relazioni. Possiamo compilare questo corpus in vari modi.

Raccogli una raccolta di frasi

Noi useremo databricks-dolly-15kun set di dati open source generato dai dipendenti di Databricks nel 2023. Questo set di dati è progettato per la messa a punto supervisionata e include quattro funzionalità: istruzioni, contesto, risposta e categoria. Dopo aver analizzato le otto categorie, ho deciso di mantenere la prima frase del contesto della information_extraction categoria. I passaggi di analisi dei dati sono descritti di seguito:

from datasets import load_dataset

# Load the dataset
dataset = load_dataset("databricks/databricks-dolly-15k")

# Choose the desired category from the dataset
ie_category = (e for e in dataset("train") if e("category")=="information_extraction")

# Retain only the context from each instance
ie_context = (e("context") for e in ie_category)

# Split the text into sentences (at the period) and keep the first sentence
reduced_context = (text.split('.')(0) + '.' for text in ie_context)

# Retain sequences of specified lengths only (use character length)
sampler = (e for e in reduced_context if 30 < len(e) < 170)

Il processo di selezione produce un set di dati comprendente 1.041 frasi. Dato che si tratta di un mini-progetto, non ho selezionato manualmente le frasi e, di conseguenza, alcuni esempi potrebbero non essere ideali per il nostro compito. In un progetto destinato alla produzione, selezionerei attentamente solo le frasi più appropriate. Tuttavia, per gli scopi di questo progetto, questo set di dati sarà sufficiente.

Formattare i dati

Dobbiamo prima creare un messaggio di sistema che definirà il prompt di input e istruirà il modello su come generare le risposte:

system_message = """You are an experienced annontator. 
Extract all entities and the relations between them from the following text.
Write the answer as a triple entity1|relationship|entitity2.
Do not add anything else.
Example Text: Alice is from France.
Answer: Alice|is from|France.
"""

Poiché si tratta di una fase sperimentale, mantengo le esigenze del modello al minimo. Ho testato diversi altri prompt, inclusi alcuni che richiedevano output in formato CoNLL in cui le entità vengono categorizzate, e il modello ha funzionato abbastanza bene. Tuttavia, per semplicità, per ora ci limiteremo alle nozioni di base.

Dobbiamo anche convertire i dati in un formato conversazionale:

messages = ((
{"role": "system","content": f"{system_message}"},
{"role": "user", "content": e}) for e in sampler)

Il client e l'API Groq

Llama3 è stato rilasciato solo pochi giorni fa e la disponibilità delle opzioni API è ancora limitata. Sebbene sia disponibile un'interfaccia di chat per Llama3–70B, questo progetto richiede un'API in grado di elaborare le mie 1.000 frasi con un paio di righe di codice. L'ho trovato eccellente Video Youtube che spiega come utilizzare gratuitamente l'API GroqCloud. Per maggiori dettagli fare riferimento al video.

Solo un promemoria: dovrai accedere e recuperare una chiave API gratuita dal GroqCloud sito web. La mia chiave API è già salvata nei segreti di Google Colab. Iniziamo inizializzando il client Groq:

import os
from groq import Groq

gclient = Groq(
api_key=userdata.get("GROQ"),
)

Successivamente dobbiamo definire un paio di funzioni di supporto che ci consentiranno di interagire con il file Meta.ai interfaccia di chat in modo efficace (questi sono adattati da Video Youtube):

import time
from tqdm import tqdm

def process_data(prompt):

"""Send one request and retrieve model's generation."""

chat_completion = gclient.chat.completions.create(
messages=prompt, # input prompt to send to the model
model="llama3-70b-8192", # according to GroqCloud labeling
temperature=0.5, # controls diversity
max_tokens=128, # max number tokens to generate
top_p=1, # proportion of likelihood weighted options to consider
stop=None, # string that signals to stop generating
stream=False, # if set partial messages are sent
)
return chat_completion.choices(0).message.content

def send_messages(messages):

"""Process messages in batches with a pause between batches."""

batch_size = 10
answers = ()

for i in tqdm(range(0, len(messages), batch_size)): # batches of size 10

batch = messages(i:i+10) # get the next batch of messages

for message in batch:
output = process_data(message)
answers.append(output)

if i + 10 < len(messages): # check if there are batches left
time.sleep(10) # wait for 10 seconds

return answers

La prima funzione process_data() funge da wrapper per la funzione di completamento della chat del client Groq. La seconda funzione send_messages()elabora i dati in piccoli batch. Se segui il collegamento Impostazioni nella pagina del parco giochi Groq, troverai un collegamento a Limiti che descrive in dettaglio le condizioni alle quali possiamo utilizzare l'API gratuita, inclusi i limiti al numero di richieste e token generati. Per evitare di superare questi limiti, ho aggiunto un ritardo di 10 secondi dopo ogni batch di 10 messaggi, anche se nel mio caso non era strettamente necessario. Potresti voler sperimentare queste impostazioni.

Ciò che resta ora è generare i nostri dati di estrazione delle relazioni e integrarli con il set di dati iniziale:

# Data generation with Llama3-70B
answers = send_messages(messages)

# Combine input data with the generated dataset
combined_dataset = ({'text': user, 'gold_re': output} for user, output in zip(sampler, answers))

Prima di procedere con la messa a punto del modello, è importante valutare le sue prestazioni su diversi campioni per determinare se la messa a punto è effettivamente necessaria.

Creazione di un set di dati di test

Selezioneremo 20 campioni dal set di dati che abbiamo appena costruito e li metteremo da parte per i test. Il resto del set di dati verrà utilizzato per la messa a punto.

import random
random.seed(17)

# Select 20 random entries
mini_data = random.sample(combined_dataset, 20)

# Build conversational format
parsed_mini_data = (({'role': 'system', 'content': system_message},
{'role': 'user', 'content': e('text')}) for e in mini_data)

# Create the training set
train_data = (item for item in combined_dataset if item not in mini_data)

Utilizzeremo l'API GroqCloud e le utilità sopra definite, specificando model=llama3-8b-8192 mentre il resto della funzione rimane invariato. In questo caso, possiamo elaborare direttamente il nostro piccolo set di dati senza preoccuparci di superare i limiti API.

Ecco un output di esempio che fornisce l'originale textdenotata la generazione Llama3-70B gold_re e la generazione Llama3-8B etichettata test_re.

{'text': 'Long before any knowledge of electricity existed, people were aware of shocks from electric fish.',
'gold_re': 'people|were aware of|shocks\nshocks|from|electric fish\nelectric fish|had|electricity',
'test_re': 'electric fish|were aware of|shocks'}

Per il set di dati completo del test, fare riferimento a Taccuino di Google Colab.

Proprio da questo esempio diventa chiaro che Llama3–8B potrebbe beneficiare di alcuni miglioramenti nelle sue capacità di estrazione delle relazioni. Lavoriamo per migliorarlo.

Utilizzeremo un arsenale completo di tecniche per assisterci, tra cui QLoRA e Flash Attention. Non approfondirò qui i dettagli della scelta degli iperparametri, ma se sei interessato a esplorare ulteriormente, dai un'occhiata a questi ottimi riferimenti (4) E (5).

La GPU A100 supporta Flash Attention e bfloat16 e possiede circa 40 GB di memoria, sufficienti per le nostre esigenze di messa a punto.

Preparazione del set di dati SFT

Iniziamo analizzando il set di dati in un formato conversazionale, incluso un messaggio di sistema, testo di input e la risposta desiderata, che ricaviamo dalla generazione Llama3–70B. Lo salviamo quindi come set di dati HuggingFace:

def create_conversation(sample):
return {
"messages": (
{"role": "system","content": system_message},
{"role": "user", "content": sample("text")},
{"role": "assistant", "content": sample("gold_re")}
)
}

from datasets import load_dataset, Dataset

train_dataset = Dataset.from_list(train_data)

# Transform to conversational format
train_dataset = train_dataset.map(create_conversation,
remove_columns=train_dataset.features,
batched=False)

Scegli il Modello

model_id  =  "meta-llama/Meta-Llama-3-8B"

Carica il tokenizzatore

from transformers import AutoTokenizer

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id,
use_fast=True,
trust_remote_code=True)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'

# Set a maximum length
tokenizer.model_max_length = 512

Scegli Parametri di quantizzazione

from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)

Carica il modello

from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training
from trl import setup_chat_format

device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None

model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map=device_map,
attn_implementation="flash_attention_2",
quantization_config=bnb_config
)

model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

Configurazione LoRA

from peft import LoraConfig

# According to Sebastian Raschka findings
peft_config = LoraConfig(
lora_alpha=128, #32
lora_dropout=0.05,
r=256, #16
bias="none",
target_modules=("q_proj", "o_proj", "gate_proj", "up_proj",
"down_proj", "k_proj", "v_proj"),
task_type="CAUSAL_LM",
)

I migliori risultati si ottengono quando si prendono di mira tutti gli strati lineari. Se i vincoli di memoria sono un problema, può essere utile optare per valori più standard come alpha=32 e ranking=16, poiché queste impostazioni comportano un numero significativamente inferiore di parametri.

Argomenti di formazione

from transformers import TrainingArguments

# Adapted from Phil Schmid blogpost
args = TrainingArguments(
output_dir=sft_model_path, # directory to save the model and repository id
num_train_epochs=2, # number of training epochs
per_device_train_batch_size=4, # batch size per device during training
gradient_accumulation_steps=2, # number of steps before performing a backward/update pass
gradient_checkpointing=True, # use gradient checkpointing to save memory, use in distributed training
optim="adamw_8bit", # choose paged_adamw_8bit if not enough memory
logging_steps=10, # log every 10 steps
save_strategy="epoch", # save checkpoint every epoch
learning_rate=2e-4, # learning rate, based on QLoRA paper
bf16=True, # use bfloat16 precision
tf32=True, # use tf32 precision
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
lr_scheduler_type="constant", # use constant learning rate scheduler
push_to_hub=True, # push model to Hugging Face hub
hub_model_id="llama3-8b-sft-qlora-re",
report_to="tensorboard", # report metrics to tensorboard
)

Se scegli di salvare il modello localmente, puoi omettere gli ultimi tre parametri. Potrebbe anche essere necessario regolare il per_device_batch_size E gradient_accumulation_steps per evitare errori di memoria esaurita (OOM).

Inizializza il trainer e addestra il modello

from trl import SFTTrainer

trainer = SFTTrainer(
model=model,
args=args,
train_dataset=sft_dataset,
peft_config=peft_config,
max_seq_length=512,
tokenizer=tokenizer,
packing=False, # True if the dataset is large
dataset_kwargs={
"add_special_tokens": False, # the template adds the special tokens
"append_concat_token": False, # no need to add additional separator token
}
)

trainer.train()
trainer.save_model()

La formazione, incluso il salvataggio del modello, ha richiesto circa 10 minuti.

Cancellamo la memoria per prepararci ai test di inferenza. Se utilizzi una GPU con meno memoria e riscontri errori CUDA Out of Memory (OOM), potrebbe essere necessario riavviare il runtime.

import torch
import gc
del model
del tokenizer
gc.collect()
torch.cuda.empty_cache()

In questo passaggio finale caricheremo il modello base in mezza precisione insieme all'adattatore Peft. Per questo test ho scelto di non unire il modello con l'adattatore.

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
import torch

# HF model
peft_model_id = "solanaO/llama3-8b-sft-qlora-re"

# Load Model with PEFT adapter
model = AutoPeftModelForCausalLM.from_pretrained(
peft_model_id,
device_map="auto",
torch_dtype=torch.float16,
offload_buffers=True
)

Successivamente, carichiamo il tokenizzatore:

okenizer = AutoTokenizer.from_pretrained(peft_model_id)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

E costruiamo la pipeline di generazione del testo:

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

Carichiamo il dataset di test, composto dai 20 campioni che abbiamo messo da parte in precedenza, e formattiamo i dati in stile conversazionale. Tuttavia, questa volta omettiamo il messaggio dell'assistente e lo formattiamo come set di dati Hugging Face:

def create_input_prompt(sample):
return {
"messages": (
{"role": "system","content": system_message},
{"role": "user", "content": sample("text")},
)
}

from datasets import Dataset

test_dataset = Dataset.from_list(mini_data)

# Transform to conversational format
test_dataset = test_dataset.map(create_input_prompt,
remove_columns=test_dataset.features,
batched=False)

Un test campione

Generiamo l'output di estrazione della relazione utilizzando SFT Llama3–8B e confrontiamolo con i due output precedenti su una singola istanza:

 Generate the input prompt
prompt = pipe.tokenizer.apply_chat_template(test_dataset(2)("messages")(:2),
tokenize=False,
add_generation_prompt=True)
# Generate the output
outputs = pipe(prompt,
max_new_tokens=128,
do_sample=False,
temperature=0.1,
top_k=50,
top_p=0.1,
)
# Display the results
print(f"Question: {test_dataset(2)('messages')(1)('content')}\n")
print(f"Gold-RE: {test_sampler(2)('gold_re')}\n")
print(f"LLama3-8B-RE: {test_sampler(2)('test_re')}\n")
print(f"SFT-Llama3-8B-RE: {outputs(0)('generated_text')(len(prompt):).strip()}")

Otteniamo quanto segue:

Question: Long before any knowledge of electricity existed, people were aware of shocks from electric fish.

Gold-RE: people|were aware of|shocks
shocks|from|electric fish
electric fish|had|electricity

LLama3-8B-RE: electric fish|were aware of|shocks

SFT-Llama3-8B-RE: people|were aware of|shocks
shocks|from|electric fish

In questo esempio, osserviamo miglioramenti significativi nelle capacità di estrazione delle relazioni di Llama3–8B attraverso la messa a punto. Nonostante il set di dati di messa a punto non sia né molto pulito né particolarmente ampio, i risultati sono impressionanti.

Per i risultati completi sul set di dati di 20 campioni, fare riferimento a Taccuino di Google Colab. Tieni presente che il test di inferenza richiede più tempo perché carichiamo il modello con mezza precisione.

In conclusione, utilizzando Llama3–70B e un set di dati disponibile, abbiamo creato con successo un set di dati sintetico che è stato poi utilizzato per mettere a punto Llama3–8B per un compito specifico. Questo processo non solo ci ha fatto familiarizzare con Llama3, ma ci ha anche permesso di applicare le semplici tecniche di Hugging Face. Abbiamo osservato che lavorare con Llama3 somiglia molto all'esperienza con Llama2, con notevoli miglioramenti che riguardano una migliore qualità dell'output e un tokenizzatore più efficace.

Per coloro che sono interessati a spingersi oltre i confini, si consideri la possibilità di sfidare il modello con compiti più complessi come la categorizzazione di entità e relazioni e l'utilizzo di queste classificazioni per costruire un grafico della conoscenza.

  1. Somin Wadhwa, Silvio Amir, Byron C. Wallace, Revisiting Relation Extraction in the era of Large Language Models, arXiv.2305.05003 (2023).
  2. Meta, presentazione di Meta Llama 3: il LLM più capace disponibile fino ad oggi, 18 aprile 2024 (collegamento).
  3. Philipp Schmid, Omar Sanseviero, Pedro Cuenca, Youndes Belkada, Leandro von Werra, Benvenuto Llama 3: il nuovo LLM aperto di Met, 18 aprile 2024.
  4. Sebastiano Raschka, Suggerimenti pratici per ottimizzare gli LLM utilizzando LoRA (adattamento di basso rango)In vista dell'IA, 19 novembre 2023.
  5. Filippo Schmid, Come perfezionare gli LLM nel 2024 con Hugging Face, 22 gennaio 2024.

databricks-dolly-15K sulla piattaforma Hugging Face (CC BY-SA 3.0)

Deposito Github

Fonte: towardsdatascience.com

Lascia un commento

Il tuo indirizzo email non sarà pubblicato. I campi obbligatori sono contrassegnati *