Anatomy of Activation Memory: Why O(L·s·h) and the Real Savings of FlashAttention
Activation memory: forward pass's most misleading memory consumer. Layer-by-layer breakdown (attn intermediates, FFN, norm, residual), FlashAttention's saved-memory math (O(s²)→O(s)), the 'sqrt(L) savings' myth of grad-checkpoint, packing + variable-length attention.
Şükrü Yusuf KAYA
35 min read
Advanced🎯 Hedef
Activation memory'yi 'gizemli bir terim' olmaktan çıkarıp layer-by-layer kâğıt-kalem hesap yapabileceğin bir şeye dönüştürmek. FlashAttention'ın neden bu kadar büyük tasarruf sağladığını matematiksel kavramak. Grad-checkpoint'in 'sqrt(L) tasarruf' söyleminin yarı-doğru olduğunu anlamak.
1. Bir Transformer Layer'ında Tutulan Aktivasyonlar#
Bir layer (decoder-only Llama-tarzı) backward için şunları saklar:
1. Input (residual stream) [b, s, h] 2. RMSNorm output before attn [b, s, h] 3. q, k, v projections 3 × [b, s, h] 4. Attention scores (softmax) [b, n_heads, s, s] ← O(s²) ⚠️ 5. Attention output [b, s, h] 6. Post-attn residual [b, s, h] 7. RMSNorm output before FFN [b, s, h] 8. FFN intermediate (gate × up) [b, s, h_ffn] ← Llama: 14336 (~3.5× h) 9. FFN output [b, s, h] 10. Pre-output residual [b, s, h]
bytes = b × s × h × 2Tek layer toplamı (FlashAttention yok):
~6 × h + n_heads × s + h_ffn = ~6 × 4096 + 32 × 4096 + 14336 = 24576 + 131072 + 14336 = 169984 elements per (b, s) cell
(Llama 3.1 8B: h=4096, n_heads=32, h_ffn=14336, layers=32)
Bütün 32 layer için:
A_no_flash = 1 × 4096 × 32 × 169984 × 2 bytes = 44.5 GB ← felaket
24 GB'a sığmaz. İşte burada FlashAttention devreye giriyor.
2. FlashAttention'ın Memory Tasarrufu#
Standard attention: softmax(QK^T) tensor'ünü tam materyalize eder, sonra V ile çarpar.
S = softmax(Q K^T / √d) # [b, h, s, s] ← O(s²) memory out = S V # [b, h, s, d]
Bu matrisinin saklanması, backward'da gradient'i hesaplamak için lazımdı. Llama 3.1 8B, s=4096 için:
SS_size = 1 × 32 × 4096 × 4096 × 2 bytes per layer = 1 GB per layer × 32 layers = 32 GB
Sadece attention matrisi 32 GB.
FlashAttention çözümü: S'yi hiç materyalize etme#
STile-by-tile online softmax + accumulate. Backward için sakladığı şey:
- (output) — [b, h, s, d]
O - (logsumexp) — [b, h, s] ← O(s), s² değil
L - Bunlar yeterli: backward'ta Q, K, V ile birlikte attention'ı yeniden hesaplar.
A_flash_attn = (O + L) bellek = (b×h×s×d + b×h×s) × 2 bytes ≈ 1 × 32 × 4096 × 128 × 2 + epsilon = 32 MB per layer × 32 = 1 GB
32 GB → 1 GB. Tek bir mimari değişiklik, 31 GB tasarruf.
Bu yüzden cookbook'taki tüm Lab'lar FlashAttention 2 (veya v3, H100 destekli) açıkken yazılır.
python
# === FlashAttention aktif mi kontrol ===import torchfrom transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained( "meta-llama/Meta-Llama-3.1-8B", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", # ← KRİTİK device_map="cuda",) print(model.config._attn_implementation) # "flash_attention_2" # Aktif mi gerçekten? mini testx = torch.randn(1, 4096, 4096, device="cuda", dtype=torch.bfloat16)torch.cuda.reset_peak_memory_stats()_ = model(input_ids=torch.randint(0, 128000, (1, 4096), device="cuda"))peak = torch.cuda.max_memory_allocated() / 1024**3print(f"forward peak: {peak:.1f} GB")# FA2: ~3-4 GB • no-FA: ~12+ GBFlashAttention'ı doğru aktive etme + doğrulama
3. Gradient Checkpointing — 'sqrt(L) Tasarruf' Efsanesi#
Klasik söylem: "gradient checkpointing activation memory'i indirir."
O(L) → O(sqrt(L))Yarı-doğru. Detay:
Klasik checkpointing (Sublinear Memory Cost, Chen et al. 2016):#
- Layer'ları segment'e böl
sqrt(L) - Sadece her segment'in girdisini sakla
- Backward'da o segment'i yeniden hesapla
Memory:
Compute: (forward iki kez)
O(sqrt(L) × s × h) + segment_recompute_buffer+%33HuggingFace Trainer'ın default'u:#
gradient_checkpointing=TrueL- Memory: per layer × L → ama her layer için sadece input saved → toplam
O(s × h)(sabit!)O(s × h) - Compute: — aynı.
+%33
Aynı bütçe iki strateji ile aynı sonuca: pratik gerçek tasarruf ~%70-80.
Cookbook ölçümü, Llama 3.1 8B s=4096:
| Config | Activation peak |
|---|---|
| No grad-ckpt | 10.7 GB |
| HF default (per-layer) | 4.2 GB |
| Sqrt-L (8 segments) | 4.5 GB |
| Custom (4 segments) | 5.6 GB |
Karar: HF default yeter; sqrt-L manuel uğraş istemiyor.
4. Sequence Packing — Activation Tasarrufunun En Büyük Single-Win'i#
Klasik FT: dataset'te 50K örnek, ortalama 800 token, max_seq_len=2048 ile padding.
- Effective compute: 50K × 800 = 40M token
- Forward edilen: 50K × 2048 = 102M token
- Boşa giden: %60 padding compute
Sequence packing: birden fazla kısa örneği tek bir 2048-token sequence'ında uçuca ekle. Attention mask'ı her örnek için ayrı (block-diagonal).
Cookbook'un Llama 3.1 8B Lab'ında packing aç → step başına 1.5-2.5x daha fazla effective token (datasetin length distribution'una bağlı).
Memory tarafı: Activations packed sequence'ın gerçek uzunluğuna göre ölçeklenir. Padding kaldığı için aynı memory, daha fazla compute.
python
# === Sequence packing + variable-length FlashAttention ===# Llama 3.1 8B üzerinde TRL SFTTrainer ilefrom trl import SFTTrainer, SFTConfigfrom transformers import AutoModelForCausalLM, AutoTokenizerfrom datasets import load_dataset model = AutoModelForCausalLM.from_pretrained( "meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype="bfloat16", attn_implementation="flash_attention_2", device_map="cuda",)tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")tok.pad_token = tok.eos_token dataset = load_dataset("malhajar/alpaca-gpt4-tr", split="train") cfg = SFTConfig( output_dir="out", num_train_epochs=1, per_device_train_batch_size=2, max_seq_length=2048, packing=True, # ← key dataset_text_field="text", bf16=True, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, optim="paged_adamw_8bit", learning_rate=2e-4, warmup_ratio=0.03, logging_steps=5, report_to="wandb",) trainer = SFTTrainer(model=model, tokenizer=tok, train_dataset=dataset, args=cfg)trainer.train() # Bench: packing=False vs True throughput karşılaştırma# packing=False: ~7,200 tokens/s (effective: %40 padding)# packing=True: ~11,900 tokens/s (~%65 daha hızlı)TRL SFTTrainer ile sequence packing aktif
🐛 FMD — 'Packing açtım, loss curve aniden 0'a düşüyor 50 step sonra'
Hipotezler: (a) Attention mask block-diagonal değil → örnek A'nın token'ları örnek B'nin token'larına bakıyor → leakage, loss sahte 0'a düşer. Çözüm: TRL 0.12+ varsayılan olarak block-diagonal mask. `packing=True` ile `max_seq_length=2048` veriliyse `SFTTrainer` doğru yapar. (b) EOS token'ı dataset'e konmamış → modeller örneklerin sınırını anlamaz. Çözüm: `tokenizer.add_special_tokens({'eos_token': '<|im_end|>'})` ve dataset'in her örneğine eos ekle. (c) Loss masking yanlış: instruction kısmının da loss'a girdi → modeller 'cevabı kopyala' öğreniyor. Drill: TRL `train_on_responses_only` template handler doğru ayarlandı mı kontrol et.
5. Bench Tablosu (Llama 3.1 8B + RTX 4090)#
| Config | Activation peak | step/s | tokens/s (eff) | Sığar 24GB'a? |
|---|---|---|---|---|
| No FA, no grad-ckpt, no pack | OOM | n/a | n/a | ❌ |
| FA2 only | 10.7 GB | 1.30 | 5300 | ✅ ama gergin |
| FA2 + grad-ckpt | 4.2 GB | 1.10 | 4500 | ✅ rahat |
| FA2 + grad-ckpt + packing | 4.5 GB | 1.78 | 11900 | ✅✅ optimal |
| + Unsloth fused | 4.0 GB | 3.10 | 20700 | ✅✅✅ |
✅ Teslim
- Bir Llama 3.2 3B'yi 3 modda çalıştır: FA kapalı, FA açık, FA + grad-ckpt + packing. 2) Peak memory'yi her birinde ölç. 3) Teorik tahminle gerçek peak arasındaki farkı raporla. 4) Sonraki ders: 1.3 — Gradient Checkpointing Trade-off Lab.
Frequently Asked Questions
FA v3 H100 native (FP8 + WGMMA). RTX 4090 (Ada) için **v2.7.x** doğru sürüm. v3'ün hopper-spesifik kernel'leri Ada'da düşük performans verir. v2.7.0.post2 cookbook standardı.
Yorumlar & Soru-Cevap
(0)Yorum yazmak için giriş yap.
Yorumlar yükleniyor...
Related Content
Part 0 — Engineering Foundations
Welcome to the Fine-Tuning Cookbook: System, Stage Taxonomy, and the Reproducibility Contract
Start LearningPart 0 — Engineering Foundations
Reproducibility Stack: Seeds, cuDNN Flags, and Deterministic CUDA — End the 'Works on My Machine' Problem
Start LearningPart 0 — Engineering Foundations