İçeriğe geç

Long-Context Dataset Engineering: NIAH, RULER ve 128K Context FT İçin Veri Hazırlama

Llama 3.1'in 128K context'ini gerçekten kullanmak: long-context SFT data nasıl üretilir? NIAH (Needle-in-Haystack) synthetic, RULER benchmark üretim reçeteleri, long-form QA dataset, code-repo concatenation, repository-level context. RTX 4090'da long-context QLoRA (128K seq) — packing dahil 22GB peak.

Şükrü Yusuf KAYA
32 dakikalık okuma
İleri
Long-Context Dataset Engineering: NIAH, RULER ve 128K Context FT İçin Veri Hazırlama
🎯 Niye long-context FT?
Llama 3.1 8B'nin context window'u 128K — ama varsayılan SFT data'sı 2K-8K. Model 128K'da "distraction" ve "middle loss" gösterir. Long-context FT modeli uzun input'ta tutarlı tutar.

1. NIAH (Needle in Haystack) Synthetic Data#

Format: Uzun bir doküman (50K-128K token, irrelevant content) + bir 'needle' cümle gizli + question.
[Long doc with random text...] [At position 87,234: "Magic sayı 47291'dir."] [More random text...] Soru: Magic sayı nedir? Cevap: 47291.
Bu 'recall' görevi — model 128K'da bilgiyi bulabiliyor mu?
def make_niah(haystack_text, needle, question, position_pct=0.5): """Needle'ı haystack'in belli bir konumuna gömer.""" tokens = haystack_text.split() n = len(tokens) insert_at = int(n * position_pct) tokens.insert(insert_at, needle) long_doc = " ".join(tokens) return { "instruction": f"Aşağıdaki dokümanda gizli bir bilgi var. Soruyu cevapla.\n\nDoküman:\n{long_doc}\n\nSoru: {question}", "response": f"Cevap: {needle.split(': ')[-1]}" }

2. RULER (NVIDIA, Hsieh et al. 2024)#

RULER = NIAH'ın evrimi. 13 task category, multiple needles, irrelevant info ile noise, multi-hop reasoning. Cookbook'un long-context Lab'ları RULER ile eval edilir.
Tasks:
  • Single-needle retrieval
  • Multi-needle (2-4 needle, all return)
  • Multi-key retrieval
  • Variable tracking (önceki turn'lerden state)
  • Common words (frequency-based)
  • Counting
  • ... 13 toplam
Maintainer'ın repo'su:
github.com/NVIDIA/RULER
. Cookbook TR adaptation'ı:
github.com/sukruyusufkaya/ruler-tr
(planlandı).

3. YaRN Rope Scaling — Long-Context'in Mimari Tarafı#

Llama 3.1'in 128K context'i YaRN rope-scaling ile mümkün:
# Llama 3.1 config "rope_scaling": { "factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192, "rope_type": "llama3" }
Model 8K'da pre-trained, 128K'ya YaRN ile interpolate edilmiş. SFT zamanı eğer 8K'da kalırsan rope-scaling'i kullanmazsın, 128K capability lost. Long-context SFT data zorunlu if you want to use long context.

4. RTX 4090'da 128K QLoRA — Memory Math#

Llama 3.1 8B QLoRA + seq_len=128K:
  • W (NF4) = 4 GB
  • A (grad-ckpt + FA + packing):
    128K × 32 × 4096 × 2 / sqrt(32) × 2.5
    = 15 GB
  • O + G + B: ~3 GB
  • Total: ~22 GB — sığar, headroom 2 GB
Batch=1 zorunlu. Throughput çok düşer (saniyede 0.1-0.3 step) ama mümkün.
Cloud alternatifi: 1× H100 80GB'da batch=4, 4-5x hızlı, daha rahat.
python
# === 128K Long-Context QLoRA — RTX 4090 ===
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
 
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype="bfloat16")
 
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
quantization_config=bnb,
attn_implementation="flash_attention_2",
torch_dtype="bfloat16",
device_map="cuda",
)
tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
 
# LoRA — sadece attn + mlp
lora = LoraConfig(r=16, alpha=32, dropout=0.05,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
task_type="CAUSAL_LM")
model = get_peft_model(model, lora)
 
# SFT — 128K seq, batch=1
cfg = SFTConfig(
output_dir="long-ctx-out",
per_device_train_batch_size=1,
gradient_accumulation_steps=16, # effective batch=16
max_seq_length=131072, # 128K
packing=False, # packing 128K'da bias getirir
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
optim="paged_adamw_8bit",
learning_rate=1e-4,
warmup_ratio=0.05,
bf16=True,
logging_steps=1,
save_steps=20,
report_to="wandb",
num_train_epochs=1,
)
 
trainer = SFTTrainer(model=model, tokenizer=tok,
train_dataset=long_ctx_dataset, args=cfg)
trainer.train()
128K long-context QLoRA RTX 4090
✅ Teslim
  1. NIAH synthetic generator yaz, 100 örnek üret. 2) 32K seq ile mini Lab koş (smaller batch). 3) Sonraki ders: 2.12 — DPO/KTO Dataset Engineering.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

İlgili İçerikler