
Llama 3.2 Vision 11B / 90B: Cross-Attention Adapter + Multi-Image FT

Llama 3.2 Vision uses Meta's cross-attention adapter approach (unlike LLaVA's MLP projector). A ViT-H/14 vision encoder is fused into the LLM through **interleaved cross-attention layers**. This lesson covers multi-image fine-tuning with an interleaved image+text format; 11B QLoRA is marginal on an RTX 4090 (~22 GB), while 90B is cloud-only.
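To make the adapter idea concrete, here is a toy gated cross-attention block in PyTorch. The dimensions, gating, and module layout are illustrative assumptions for intuition, not the actual Mllama implementation:

```python
# Toy sketch of a gated cross-attention adapter layer (illustrative only,
# NOT Meta's actual Mllama code): text tokens query the vision encoder's
# outputs, and a tanh gate initialized at zero makes the block an identity
# at the start of training, preserving the frozen LLM's behavior.
import torch
import torch.nn as nn

class GatedCrossAttentionBlock(nn.Module):
    def __init__(self, d_model: int = 64, n_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))  # tanh(0) = 0 -> identity at init

    def forward(self, text_hidden: torch.Tensor, image_hidden: torch.Tensor) -> torch.Tensor:
        # queries = text hidden states, keys/values = image patch embeddings
        attn_out, _ = self.attn(text_hidden, image_hidden, image_hidden)
        return text_hidden + torch.tanh(self.gate) * attn_out

block = GatedCrossAttentionBlock()
text = torch.randn(2, 10, 64)     # (batch, text_tokens, d_model)
image = torch.randn(2, 1601, 64)  # (batch, image_patches, d_model)
out = block(text, image)
print(out.shape)  # torch.Size([2, 10, 64])
```

Because the gate starts at zero, the output initially equals the text hidden states; the adapter's influence grows only as the gate is trained.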

Şükrü Yusuf KAYA
30-minute read
Advanced
```python
# === Llama 3.2 11B Vision QLoRA (RTX 4090) ===
from transformers import MllamaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
import torch

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

# LoRA — targets the attention projections in both the cross-attn adapter and the text LLM
lora = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)

# Visual instruction dataset
dataset = load_dataset("HuggingFaceM4/the_cauldron", "vqav2", split="train[:5000]")
# Format: {"images": [PIL], "texts": [{"user": "...", "assistant": "..."}]}

def format_vlm(example):
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": example["texts"][0]["user"]},
        ]},
        {"role": "assistant", "content": example["texts"][0]["assistant"]},
    ]
    inputs = processor(
        text=processor.apply_chat_template(messages, tokenize=False),
        images=example["images"],
        return_tensors="pt", padding=True,
    )
    return inputs

# Train (~6 hours for 5000 samples on an RTX 4090)
cfg = SFTConfig(
    output_dir="llama-3.2-11b-vision-tr",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    bf16=True, optim="paged_adamw_8bit",
    max_seq_length=4096,
    logging_steps=5, report_to="wandb",
)

trainer = SFTTrainer(
    model=model,
    args=cfg,
    train_dataset=dataset,
    # NOTE: a VLM run also needs a data_collator that applies `processor`
    # per batch; `format_vlm` above shows the per-example preprocessing.
)
trainer.train()
```
Llama 3.2 11B Vision QLoRA — RTX 4090
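For multi-image fine-tuning, the chat-template message simply interleaves several `{"type": "image"}` placeholders with text; each placeholder is matched positionally with one entry in the `images` list passed to the processor. A minimal sketch (the question strings are illustrative):

```python
# Interleaved multi-image conversation for the Mllama chat template:
# each {"type": "image"} placeholder pairs positionally with one PIL image
# in the `images` argument of the processor call.
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "What is shown in the first image?"},
        {"type": "image"},
        {"type": "text", "text": "How does the second image differ from it?"},
    ]},
]

# The number of image placeholders must equal len(images) at processing time:
n_placeholders = sum(
    1 for turn in messages
    for part in turn["content"]
    if part["type"] == "image"
)
print(n_placeholders)  # 2
```

A mismatch between placeholder count and the number of images passed to the processor is the most common preprocessing error in multi-image runs, so counting placeholders up front is a cheap sanity check.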
✅ Deliverables
  1) Mini fine-tune of Llama 3.2 11B Vision. 2) Compare pre- and post-FT inference. 3) Next lesson: 6.4 — Qwen 2.5-VL Dynamic Resolution + TR OCR.
