Un tutorial sulla messa a punto di MLLM utilizzando il nuovissimo modello Mini-InternVL tascabile

fotografato da Maarten van den Heuvel SU Unsplash

Il mondo dei modelli linguistici di grandi dimensioni (LLM) è in continua evoluzione, con nuovi progressi che emergono rapidamente. Un'area interessante è lo sviluppo di LLM multimodali (MLLM), in grado di comprendere e interagire sia con testi che con immagini. Ciò apre un mondo di possibilità per attività come la comprensione dei documenti, la risposta visiva alle domande e altro ancora.

Recentemente ho scritto un post generale su uno di questi modelli che puoi controllare qui:

Ma in questo esploreremo una potente combinazione: il modello InternVL e la tecnica di messa a punto QLoRA. Ci concentreremo su come possiamo personalizzare facilmente tali modelli per qualsiasi caso d'uso specifico. Utilizzeremo questi strumenti per creare una pipeline di comprensione delle ricevute che estrae informazioni chiave come il nome dell'azienda, l'indirizzo e l'importo totale dell'acquisto con elevata precisione.

Questo progetto mira a sviluppare un sistema in grado di estrarre con precisione informazioni specifiche dalle ricevute scansionate, utilizzando le capacità di InternVL. Il compito rappresenta una sfida unica, che richiede non solo una solida elaborazione del linguaggio naturale (NLP), ma anche la capacità di interpretare il layout visivo dell'immagine di input. Ciò ci consentirà di creare un'unica pipeline end-to-end senza OCR che dimostri una forte generalizzazione su documenti complessi.

Per addestrare e valutare il nostro modello, utilizzeremo il file SROIE set di dati. SROIE fornisce 1000 immagini di ricevute scansionate, ciascuna annotata con entità chiave come:

  • Azienda: il nome del negozio o dell'attività.
  • Data: la data di acquisto.
  • Indirizzo: l'indirizzo del negozio.
  • Totale: l'importo totale pagato.
Fonte: https://arxiv.org/pdf/2103.10213.pdf.

Valuteremo le prestazioni del nostro modello utilizzando un punteggio di somiglianza fuzzy, una metrica che misura la somiglianza tra le entità previste e quelle reali. Questa metrica varia da 0 (risultati irrilevanti) a 100 (previsioni perfette).

InternVL è una famiglia di LLM multimodali di OpenGVLab, progettata per eccellere in attività che coinvolgono immagini e testo. La sua architettura combina un modello di visione (come InternViT) con un modello linguistico (come InternLM2 o Phi-3). Ci concentreremo sulla variante Mini-InternVL-Chat-2B-V1–5, una versione più piccola adatta per l'esecuzione su GPU di livello consumer.

I principali punti di forza di InternVL:

  • Efficienza: le sue dimensioni compatte consentono formazione e inferenza efficienti.
  • Precisione: Nonostante sia più piccolo, raggiunge prestazioni competitive in vari benchmark.
  • Funzionalità multimodali: combina perfettamente la comprensione di immagini e testo.

Demo: puoi esplorare una demo dal vivo di InternVL Qui.

Per aumentare ulteriormente le prestazioni del nostro modello, utilizzeremo QLoRA, una tecnica di perfezionamento che riduce significativamente il consumo di memoria preservando le prestazioni. Ecco come funziona:

  1. Quantizzazione: l'LLM pre-addestrato viene quantizzato con una precisione a 4 bit, riducendone l'ingombro in memoria.
  2. Adattatori di basso rango (LoRA): invece di modificare tutti i parametri del modello pre-addestrato, LoRA aggiunge piccoli adattatori addestrabili alla rete. Questi adattatori acquisiscono informazioni specifiche sull'attività senza richiedere modifiche al modello principale.
  3. Formazione efficiente: la combinazione di quantizzazione e LoRA consente una messa a punto efficiente anche su GPU con memoria limitata.

Immergiamoci nel codice. Innanzitutto, valuteremo le prestazioni di base di Mini-InternVL-Chat-2B-V1–5 senza alcuna messa a punto:

quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)

model = InternVLChatModel.from_pretrained(
args.path,
device_map={"": 0},
quantization_config=quant_config if args.quant else None,
torch_dtype=torch.bfloat16,
)

tokenizer = InternLM2Tokenizer.from_pretrained(args.path)
# set the max number of tiles in `max_num`

model.eval()

pixel_values = (
load_image(image_base_path / "X51005255805.jpg", max_num=6)
.to(torch.bfloat16)
.cuda()
)

generation_config = dict(
num_beams=1,
max_new_tokens=512,
do_sample=False,
)

# single-round single-image conversation
question = (
"Extract the company, date, address and total in json format."
"Respond with a valid JSON only."
)
# print(model)
response = model.chat(tokenizer, pixel_values, question, generation_config)

print(response)

Il risultato:

```json
{
"company": "SAM SAM TRADING CO",
"date": "Fri, 29-12-2017",
"address": "67, JLN MENHAW 25/63 TNN SRI HUDA, 40400 SHAH ALAM",
"total": "RM 14.10"
}
```

Questo codice:

  1. Carica il modello dall'hub Hugging Face.
  2. Carica un'immagine di ricevuta di esempio e la converte in un tensore.
  3. Formula una domanda chiedendo al modello di estrarre informazioni rilevanti dall'immagine.
  4. Esegue il modello e restituisce le informazioni estratte in formato JSON.

Questa valutazione zero-shot mostra risultati impressionanti, ottenendo un punteggio medio di somiglianza fuzzy di 74,24%. Ciò dimostra la capacità di InternVL di comprendere le ricevute ed estrarre informazioni senza alcuna regolazione.

Per aumentare ulteriormente la precisione, metteremo a punto il modello utilizzando QLoRA. Ecco come lo implementiamo:

_data = load_data(args.data_path, fold="train")

# Quantization Config
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)

model = InternVLChatModel.from_pretrained(
path,
device_map={"": 0},
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)

tokenizer = InternLM2Tokenizer.from_pretrained(path)

# set the max number of tiles in `max_num`
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print("img_context_token_id", img_context_token_id)
model.img_context_token_id = img_context_token_id

model.config.llm_config.use_cache = False

model = wrap_lora(model, r=128, lora_alpha=256)

training_data = SFTDataset(
data=_data, template=model.config.template, tokenizer=tokenizer
)

collator = CustomDataCollator(pad_token=tokenizer.pad_token_id, ignore_index=-100)

img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print("img_context_token_id", img_context_token_id)
model.img_context_token_id = img_context_token_id
print("model.img_context_token_id", model.img_context_token_id)

train_params = TrainingArguments(
output_dir=str(BASE_PATH / "results_modified"),
num_train_epochs=EPOCHS,
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
optim="paged_adamw_32bit",
save_steps=len(training_data) // 10,
logging_steps=len(training_data) // 50,
learning_rate=5e-4,
lr_scheduler_type="cosine",
warmup_steps=100,
weight_decay=0.001,
max_steps=-1,
group_by_length=False,
max_grad_norm=1.0,
)
# Trainer
fine_tuning = SFTTrainer(
model=model,
train_dataset=training_data,
dataset_text_field="###",
tokenizer=tokenizer,
args=train_params,
data_collator=collator,
max_seq_length=tokenizer.model_max_length,
)

print(fine_tuning.model.print_trainable_parameters())
# Training
fine_tuning.train()
# Save Model
fine_tuning.model.save_pretrained(refined_model)

Questo codice:

  1. Carica il modello con la quantizzazione abilitata.
  2. Avvolge il modello con LoRA, aggiungendo adattatori addestrabili.
  3. Crea un set di dati dal set di dati SROIE.
  4. Definisce gli argomenti di training come la velocità di apprendimento, la dimensione del batch e le epoche.
  5. Inizializza un trainer per gestire il processo di formazione.
  6. Addestra il modello sul set di dati SROIE.
  7. Salva il modello ottimizzato.

Ecco un esempio di confronto tra il modello base e il modello ottimizzato QLoRA:

Ground Truth: 

{
"company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2018",
"address": "NO 4,JALAN PERJIRANAN 10, TAMAN AIR BIRU, 81700 PASIR GUDANG, JOHOR.",
"total": "72.00"
}

Prediction Base: KO

```json
{
"company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2016",
"address": "JM092487-D",
"total": "67.92"
}
```

Prediction QLoRA: OK

{
"company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2018",
"address": "NO 4, JALAN PERUBANAN 10, TAMAN AIR BIRU, 81700 PASIR GUDANG, JOHOR",
"total": "72.00"
}

Dopo la messa a punto con QLoRA, il nostro modello raggiunge risultati notevoli 95,4% punteggio di somiglianza fuzzy, un miglioramento significativo rispetto alla prestazione di base (74,24%). Ciò dimostra la potenza di QLoRA nell'aumentare la precisione del modello senza richiedere ingenti risorse di elaborazione (formazione di 15 minuti su 600 campioni su una GPU RTX 3080).

Abbiamo creato con successo una solida pipeline per la comprensione delle ricevute utilizzando InternVL e QLoRA. Questo approccio mostra il potenziale dei LLM multimodali per attività del mondo reale come l'analisi dei documenti e l'estrazione di informazioni. In questo caso d'uso di esempio, abbiamo guadagnato 30 punti nella qualità della previsione utilizzando alcune centinaia di esempi e pochi minuti di tempo di elaborazione su una GPU consumer.

Puoi trovare l'implementazione completa del codice per questo progetto Qui.

Lo sviluppo dei LLM multimodali è solo all’inizio e il futuro riserva interessanti possibilità. L’area dell’elaborazione automatizzata dei documenti ha un potenziale immenso nell’era dei MLLM. Questi modelli possono rivoluzionare il modo in cui estraiamo informazioni da contratti, fatture e altri documenti, richiedendo dati di formazione minimi. Integrando testo e visione, possono analizzare il layout di documenti complessi con una precisione senza precedenti, aprendo la strada a una gestione delle informazioni più efficiente e intelligente.

Il futuro dell'intelligenza artificiale è multimodale e InternVL e QLoRA sono strumenti potenti per aiutarci a sbloccare il suo potenziale con un budget di elaborazione ridotto.

Collegamenti:

Codice: https://github.com/CVxTz/doc-llm

Origine del set di dati: https://rrc.cvc.uab.es/?ch=13&com=introduzione
Licenza set di dati: concessa in licenza con a Licenza Internazionale Creative Commons Attribuzione 4.0.

Fonte: towardsdatascience.com

Lascia un commento

Il tuo indirizzo email non sarà pubblicato. I campi obbligatori sono contrassegnati *