Gradient Checkpointing Trade-off Lab: Memory'yi Kompresleyip Compute'ı Krediye Yatırmak
Gradient checkpointing'in seçim ağacı: per-layer mı, segment-based mi, custom selective mi? Re-entrant vs non-re-entrant farkı, torch.utils.checkpoint vs HF Trainer kwargs, selective checkpointing (sadece attn'i checkpoint et, FFN'i değil). RTX 4090 + Llama 3.1 8B üzerinde 5 strateji bench'i.
Şükrü Yusuf KAYA
28 dakikalık okuma
İleri🎯 Bu Lab'da
Aynı modeli 5 ayrı checkpointing stratejisiyle koşacaksın: hiç yok, HF Trainer default, sqrt-segment, selective (sadece attn), Unsloth-style. Memory peak ile throughput trade-off'unu tablo değil deneyim olarak kavrayacaksın.
1. Re-entrant vs Non-re-entrant — En Çok Yapılan Hata#
torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True/False)| Mod | Davranış | Trade-off |
|---|---|---|
use_reentrant=True | Autograd graph içinde kapatılmış custom function | Daha az memory ama early-return bug, control-flow ile uyumsuz |
use_reentrant=False | Saved-tensor hook mekanizması | %2-5 fazla memory ama bug-free, control-flow OK |
PyTorch ≥ 2.5'te kuvvetle önerilir. Eski model kodları (örn. HuggingFace eski) hâlâ default'u kullanır → cookbook'un her Trainer config'inde override:
use_reentrant=FalseTruetrainer_args.gradient_checkpointing = True trainer_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
2. 5 Strateji#
| Strateji | Memory | Compute overhead | İmplementasyon |
|---|---|---|---|
| (a) Off | 100% | 0% | grad-ckpt kapalı |
| (b) Per-layer (HF default) | ~35% | +%33 | enable_input_require_grads() |
| (c) Sqrt-segment | ~38% | +%33 | sqrt(L) segment, manual |
| (d) Selective (attn-only) | ~55% | +%18 | sadece attn modülünü çechkpoint, FFN bypass |
| (e) Unsloth selective | ~30% | +%12 | Triton fused kernel + selective |
(d) selective: attention'ın activation'ı en pahalı (O(s²) tensors). FFN intermediate sadece h_ffn × s, daha küçük. Selective: sadece attention'ı recompute, FFN olduğu gibi sakla → memory %55 (aradaki kayıp), compute %18 (yarısı).
python
# === Selective gradient checkpointing — sadece attn'i recompute et ===import torchfrom torch.utils.checkpoint import checkpointfrom transformers.models.llama.modeling_llama import LlamaDecoderLayer # Monkey-patch decoder layer'ı: sadece self_attn'ı checkpoint, MLP'yi değiloriginal_forward = LlamaDecoderLayer.forward def selective_forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, **kw): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self-attention'ı checkpoint def attn_fn(h): return self.self_attn( hidden_states=h, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) attn_out = checkpoint(attn_fn, hidden_states, use_reentrant=False) if isinstance(attn_out, tuple): hidden_states = attn_out[0] else: hidden_states = attn_out hidden_states = residual + hidden_states # FFN'ı checkpoint'siz çalıştır — memory'sini sakla, recompute'a girme residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return (hidden_states,) LlamaDecoderLayer.forward = selective_forward# Now training with selective ckptselective gradient checkpointing — sadece attention'ı recompute
3. Bench Sonuçları (Llama 3.1 8B + RTX 4090 + QLoRA r=32)#
50 step warmup + 100 step ölçüm, seq_len=4096, batch=2.
| Strateji | Peak GB | step/s | tokens/s | wall-clock 1 epoch |
|---|---|---|---|---|
| (a) Off | OOM | — | — | — |
| (b) HF per-layer (default) | 13.4 | 1.78 | 7290 | 1h 22m |
| (c) Sqrt-segment | 14.2 | 1.83 | 7500 | 1h 19m |
| (d) Selective attn-only | 17.1 | 2.08 | 8520 | 1h 10m |
| (e) Unsloth selective | 11.8 | 3.10 | 12700 | 47m |
Karar tablosu:
- Memory en sıkı → (b) HF per-layer
- Throughput en yüksek + sığar → (e) Unsloth
- Manual control gerekli → (d) selective
- Sıkıntı yok bütçe rahat → (a) Off ama 8B için OOM
Cookbook default'u: (e) Unsloth + selective (Llama / Qwen / Gemma için Unsloth destekli). Diğer modeller için (b) HF per-layer.
✅ Lab teslim
- 5 stratejiyi de RTX 4090'ında koş. 2) Memory + throughput tablosunu W&B'de side-by-side raporla. 3) Sonraki ders: 1.4 — Mixed Precision Mimarisi (bf16 vs fp16 vs fp8).
Sık Sorulan Sorular
\`torch.distributed.algorithms._checkpoint.checkpoint_wrapper\` selective_ckpt_fn API'sı sunuyor (PyTorch 2.4+). HF transformers tarafında entegrasyon ergonomic değil — cookbook manual monkey-patch gösterir. Production'da Axolotl'un \`gradient_checkpointing: 'selective'\` config option'ı bunu sağlıyor.
Yorumlar & Soru-Cevap
(0)Yorum yazmak için giriş yap.
Yorumlar yükleniyor...
İlgili İçerikler
Part 0 — Engineering Foundations
Fine-Tuning Cookbook'a Hoş Geldin: Sistematik, Stage Taksonomisi ve Reproducibility Kontratı
Öğrenmeye BaşlaPart 0 — Engineering Foundations
Reproducibility Stack: Seeds, cuDNN Flags ve Deterministic CUDA — 'Sende Niye Çalışıyor Bende Çalışmıyor' Sorununu Bitir
Öğrenmeye BaşlaPart 0 — Engineering Foundations