Skip to content

Anatomy of Activation Memory: Why O(L·s·h) and the Real Savings of FlashAttention

Activation memory: forward pass's most misleading memory consumer. Layer-by-layer breakdown (attn intermediates, FFN, norm, residual), FlashAttention's saved-memory math (O(s²)→O(s)), the 'sqrt(L) savings' myth of grad-checkpoint, packing + variable-length attention.

Şükrü Yusuf KAYA
35 min read
Advanced
Activation Memory Anatomisi: Niye O(L·s·h) ve FlashAttention'ın Gerçek Tasarrufu
🎯 Hedef
Activation memory'yi 'gizemli bir terim' olmaktan çıkarıp layer-by-layer kâğıt-kalem hesap yapabileceğin bir şeye dönüştürmek. FlashAttention'ın neden bu kadar büyük tasarruf sağladığını matematiksel kavramak. Grad-checkpoint'in 'sqrt(L) tasarruf' söyleminin yarı-doğru olduğunu anlamak.

1. Bir Transformer Layer'ında Tutulan Aktivasyonlar#

Bir layer (decoder-only Llama-tarzı) backward için şunları saklar:
1. Input (residual stream) [b, s, h] 2. RMSNorm output before attn [b, s, h] 3. q, k, v projections 3 × [b, s, h] 4. Attention scores (softmax) [b, n_heads, s, s] ← O(s²) ⚠️ 5. Attention output [b, s, h] 6. Post-attn residual [b, s, h] 7. RMSNorm output before FFN [b, s, h] 8. FFN intermediate (gate × up) [b, s, h_ffn] ← Llama: 14336 (~3.5× h) 9. FFN output [b, s, h] 10. Pre-output residual [b, s, h]
bytes = b × s × h × 2
(bf16)
Tek layer toplamı (FlashAttention yok):
~6 × h + n_heads × s + h_ffn = ~6 × 4096 + 32 × 4096 + 14336 = 24576 + 131072 + 14336 = 169984 elements per (b, s) cell
(Llama 3.1 8B: h=4096, n_heads=32, h_ffn=14336, layers=32)
Bütün 32 layer için:
A_no_flash = 1 × 4096 × 32 × 169984 × 2 bytes = 44.5 GB ← felaket
24 GB'a sığmaz. İşte burada FlashAttention devreye giriyor.

2. FlashAttention'ın Memory Tasarrufu#

Standard attention: softmax(QK^T) tensor'ünü tam materyalize eder, sonra V ile çarpar.
S = softmax(Q K^T / √d) # [b, h, s, s] ← O(s²) memory out = S V # [b, h, s, d]
Bu
S
matrisinin saklanması, backward'da gradient'i hesaplamak için lazımdı. Llama 3.1 8B, s=4096 için:
S_size = 1 × 32 × 4096 × 4096 × 2 bytes per layer = 1 GB per layer × 32 layers = 32 GB
Sadece attention matrisi 32 GB.

FlashAttention çözümü:
S
'yi hiç materyalize etme#

Tile-by-tile online softmax + accumulate. Backward için sakladığı şey:
  • O
    (output) — [b, h, s, d]
  • L
    (logsumexp) — [b, h, s] ← O(s), s² değil
  • Bunlar yeterli: backward'ta Q, K, V ile birlikte attention'ı yeniden hesaplar.
A_flash_attn = (O + L) bellek = (b×h×s×d + b×h×s) × 2 bytes ≈ 1 × 32 × 4096 × 128 × 2 + epsilon = 32 MB per layer × 32 = 1 GB
32 GB → 1 GB. Tek bir mimari değişiklik, 31 GB tasarruf.
Bu yüzden cookbook'taki tüm Lab'lar FlashAttention 2 (veya v3, H100 destekli) açıkken yazılır.
python
# === FlashAttention aktif mi kontrol ===
import torch
from transformers import AutoModelForCausalLM
 
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3.1-8B",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # ← KRİTİK
device_map="cuda",
)
 
print(model.config._attn_implementation) # "flash_attention_2"
 
# Aktif mi gerçekten? mini test
x = torch.randn(1, 4096, 4096, device="cuda", dtype=torch.bfloat16)
torch.cuda.reset_peak_memory_stats()
_ = model(input_ids=torch.randint(0, 128000, (1, 4096), device="cuda"))
peak = torch.cuda.max_memory_allocated() / 1024**3
print(f"forward peak: {peak:.1f} GB")
# FA2: ~3-4 GB • no-FA: ~12+ GB
FlashAttention'ı doğru aktive etme + doğrulama

3. Gradient Checkpointing — 'sqrt(L) Tasarruf' Efsanesi#

Klasik söylem: "gradient checkpointing activation memory'i
O(L) → O(sqrt(L))
indirir."
Yarı-doğru. Detay:

Klasik checkpointing (Sublinear Memory Cost, Chen et al. 2016):#

  • Layer'ları
    sqrt(L)
    segment'e böl
  • Sadece her segment'in girdisini sakla
  • Backward'da o segment'i yeniden hesapla
Memory:
O(sqrt(L) × s × h) + segment_recompute_buffer
Compute:
+%33
(forward iki kez)

HuggingFace Trainer'ın default'u:#

gradient_checkpointing=True
→ segment_count =
L
(her layer ayrı checkpoint).
  • Memory:
    O(s × h)
    per layer × L → ama her layer için sadece input saved → toplam
    O(s × h)
    (sabit!)
  • Compute:
    +%33
    — aynı.
Aynı bütçe iki strateji ile aynı sonuca: pratik gerçek tasarruf ~%70-80.
Cookbook ölçümü, Llama 3.1 8B s=4096:
ConfigActivation peak
No grad-ckpt10.7 GB
HF default (per-layer)4.2 GB
Sqrt-L (8 segments)4.5 GB
Custom (4 segments)5.6 GB
Karar: HF default yeter; sqrt-L manuel uğraş istemiyor.

4. Sequence Packing — Activation Tasarrufunun En Büyük Single-Win'i#

Klasik FT: dataset'te 50K örnek, ortalama 800 token, max_seq_len=2048 ile padding.
  • Effective compute: 50K × 800 = 40M token
  • Forward edilen: 50K × 2048 = 102M token
  • Boşa giden: %60 padding compute
Sequence packing: birden fazla kısa örneği tek bir 2048-token sequence'ında uçuca ekle. Attention mask'ı her örnek için ayrı (block-diagonal).
Cookbook'un Llama 3.1 8B Lab'ında packing aç → step başına 1.5-2.5x daha fazla effective token (datasetin length distribution'una bağlı).
Memory tarafı: Activations packed sequence'ın gerçek uzunluğuna göre ölçeklenir. Padding kaldığı için aynı memory, daha fazla compute.
python
# === Sequence packing + variable-length FlashAttention ===
# Llama 3.1 8B üzerinde TRL SFTTrainer ile
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
 
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
torch_dtype="bfloat16",
attn_implementation="flash_attention_2",
device_map="cuda",
)
tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
tok.pad_token = tok.eos_token
 
dataset = load_dataset("malhajar/alpaca-gpt4-tr", split="train")
 
cfg = SFTConfig(
output_dir="out",
num_train_epochs=1,
per_device_train_batch_size=2,
max_seq_length=2048,
packing=True, # ← key
dataset_text_field="text",
bf16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
optim="paged_adamw_8bit",
learning_rate=2e-4,
warmup_ratio=0.03,
logging_steps=5,
report_to="wandb",
)
 
trainer = SFTTrainer(model=model, tokenizer=tok, train_dataset=dataset, args=cfg)
trainer.train()
 
# Bench: packing=False vs True throughput karşılaştırma
# packing=False: ~7,200 tokens/s (effective: %40 padding)
# packing=True: ~11,900 tokens/s (~%65 daha hızlı)
TRL SFTTrainer ile sequence packing aktif
🐛 FMD — 'Packing açtım, loss curve aniden 0'a düşüyor 50 step sonra'
Hipotezler: (a) Attention mask block-diagonal değil → örnek A'nın token'ları örnek B'nin token'larına bakıyor → leakage, loss sahte 0'a düşer. Çözüm: TRL 0.12+ varsayılan olarak block-diagonal mask. `packing=True` ile `max_seq_length=2048` veriliyse `SFTTrainer` doğru yapar. (b) EOS token'ı dataset'e konmamış → modeller örneklerin sınırını anlamaz. Çözüm: `tokenizer.add_special_tokens({'eos_token': '<|im_end|>'})` ve dataset'in her örneğine eos ekle. (c) Loss masking yanlış: instruction kısmının da loss'a girdi → modeller 'cevabı kopyala' öğreniyor. Drill: TRL `train_on_responses_only` template handler doğru ayarlandı mı kontrol et.

5. Bench Tablosu (Llama 3.1 8B + RTX 4090)#

ConfigActivation peakstep/stokens/s (eff)Sığar 24GB'a?
No FA, no grad-ckpt, no packOOMn/an/a
FA2 only10.7 GB1.305300✅ ama gergin
FA2 + grad-ckpt4.2 GB1.104500✅ rahat
FA2 + grad-ckpt + packing4.5 GB1.7811900✅✅ optimal
+ Unsloth fused4.0 GB3.1020700✅✅✅
✅ Teslim
  1. Bir Llama 3.2 3B'yi 3 modda çalıştır: FA kapalı, FA açık, FA + grad-ckpt + packing. 2) Peak memory'yi her birinde ölç. 3) Teorik tahminle gerçek peak arasındaki farkı raporla. 4) Sonraki ders: 1.3 — Gradient Checkpointing Trade-off Lab.

Frequently Asked Questions

FA v3 H100 native (FP8 + WGMMA). RTX 4090 (Ada) için **v2.7.x** doğru sürüm. v3'ün hopper-spesifik kernel'leri Ada'da düşük performans verir. v2.7.0.post2 cookbook standardı.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

Related Content