Skip to content

Gradient Checkpointing Trade-off Lab: Compressing Memory by Crediting Compute

Decision tree for gradient checkpointing: per-layer, segment-based, custom selective? Re-entrant vs non-re-entrant difference, torch.utils.checkpoint vs HF Trainer kwargs, selective checkpointing. 5-strategy bench on RTX 4090 + Llama 3.1 8B.

Şükrü Yusuf KAYA
28 min read
Advanced
Gradient Checkpointing Trade-off Lab: Memory'yi Kompresleyip Compute'ı Krediye Yatırmak
🎯 Bu Lab'da
Aynı modeli 5 ayrı checkpointing stratejisiyle koşacaksın: hiç yok, HF Trainer default, sqrt-segment, selective (sadece attn), Unsloth-style. Memory peak ile throughput trade-off'unu tablo değil deneyim olarak kavrayacaksın.

1. Re-entrant vs Non-re-entrant — En Çok Yapılan Hata#

torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True/False)
arasındaki seçim kritik:
ModDavranışTrade-off
use_reentrant=True
(default <2.5)
Autograd graph içinde kapatılmış custom functionDaha az memory ama early-return bug, control-flow ile uyumsuz
use_reentrant=False
(modern)
Saved-tensor hook mekanizması%2-5 fazla memory ama bug-free, control-flow OK
PyTorch ≥ 2.5'te
use_reentrant=False
kuvvetle önerilir. Eski model kodları (örn. HuggingFace eski) hâlâ
True
default'u kullanır → cookbook'un her Trainer config'inde override:
trainer_args.gradient_checkpointing = True trainer_args.gradient_checkpointing_kwargs = {"use_reentrant": False}

2. 5 Strateji#

StratejiMemoryCompute overheadİmplementasyon
(a) Off100%0%grad-ckpt kapalı
(b) Per-layer (HF default)~35%+%33
enable_input_require_grads()
+ her layer'a wrapper
(c) Sqrt-segment~38%+%33sqrt(L) segment, manual
(d) Selective (attn-only)~55%+%18sadece attn modülünü çechkpoint, FFN bypass
(e) Unsloth selective~30%+%12Triton fused kernel + selective
(d) selective: attention'ın activation'ı en pahalı (O(s²) tensors). FFN intermediate sadece h_ffn × s, daha küçük. Selective: sadece attention'ı recompute, FFN olduğu gibi sakla → memory %55 (aradaki kayıp), compute %18 (yarısı).
python
# === Selective gradient checkpointing — sadece attn'i recompute et ===
import torch
from torch.utils.checkpoint import checkpoint
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
# Monkey-patch decoder layer'ı: sadece self_attn'ı checkpoint, MLP'yi değil
original_forward = LlamaDecoderLayer.forward
 
def selective_forward(self, hidden_states, attention_mask=None, position_ids=None,
past_key_value=None, output_attentions=False, use_cache=False, **kw):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
 
# Self-attention'ı checkpoint
def attn_fn(h):
return self.self_attn(
hidden_states=h, attention_mask=attention_mask,
position_ids=position_ids, past_key_value=past_key_value,
output_attentions=output_attentions, use_cache=use_cache,
)
 
attn_out = checkpoint(attn_fn, hidden_states, use_reentrant=False)
if isinstance(attn_out, tuple):
hidden_states = attn_out[0]
else:
hidden_states = attn_out
hidden_states = residual + hidden_states
 
# FFN'ı checkpoint'siz çalıştır — memory'sini sakla, recompute'a girme
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
 
return (hidden_states,)
 
LlamaDecoderLayer.forward = selective_forward
# Now training with selective ckpt
selective gradient checkpointing — sadece attention'ı recompute

3. Bench Sonuçları (Llama 3.1 8B + RTX 4090 + QLoRA r=32)#

50 step warmup + 100 step ölçüm, seq_len=4096, batch=2.
StratejiPeak GBstep/stokens/swall-clock 1 epoch
(a) OffOOM
(b) HF per-layer (default)13.41.7872901h 22m
(c) Sqrt-segment14.21.8375001h 19m
(d) Selective attn-only17.12.0885201h 10m
(e) Unsloth selective11.83.101270047m
Karar tablosu:
  • Memory en sıkı → (b) HF per-layer
  • Throughput en yüksek + sığar → (e) Unsloth
  • Manual control gerekli → (d) selective
  • Sıkıntı yok bütçe rahat → (a) Off ama 8B için OOM
Cookbook default'u: (e) Unsloth + selective (Llama / Qwen / Gemma için Unsloth destekli). Diğer modeller için (b) HF per-layer.
✅ Lab teslim
  1. 5 stratejiyi de RTX 4090'ında koş. 2) Memory + throughput tablosunu W&B'de side-by-side raporla. 3) Sonraki ders: 1.4 — Mixed Precision Mimarisi (bf16 vs fp16 vs fp8).

Frequently Asked Questions

\`torch.distributed.algorithms._checkpoint.checkpoint_wrapper\` selective_ckpt_fn API'sı sunuyor (PyTorch 2.4+). HF transformers tarafında entegrasyon ergonomic değil — cookbook manual monkey-patch gösterir. Production'da Axolotl'un \`gradient_checkpointing: 'selective'\` config option'ı bunu sağlıyor.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

Related Content