
FlashAttention: IO-Aware Attention — The Dao 2022 Algorithm and Modern Implementations

The mathematical and systems-level anatomy of FlashAttention: why standard attention is memory-bound, the GPU memory hierarchy (HBM vs SRAM), tile-based computation, online softmax, recomputation in the backward pass. The evolution from FlashAttention-1 (Dao 2022) through FlashAttention-2 to FlashAttention-3. The PyTorch flash_attn library, performance benchmarks, and long-context enablement.

Şükrü Yusuf KAYA
75 minute read
Advanced
⚡ FlashAttention — the algorithm that made long context possible
For the five years from Vaswani 2017 to FlashAttention, attention was limited by its quadratic memory; 8K context was the practical ceiling for models. When Tri Dao published the FlashAttention paper in 2022, everything changed: the same mathematical result, but dramatically less memory. Llama-3's 128K context, GPT-4o's 128K context, Claude's 200K — all made possible by FlashAttention. The key idea: never fully materialize the attention matrix; use the GPU memory hierarchy (slow HBM + fast SRAM) intelligently, tile by tile; compute incrementally with online-softmax math. 75 minutes from now you will have a deep grasp of FlashAttention's systems and mathematical foundations, the FA-1/2/3 evolution, and production library usage.

Lesson Map (10 Sections)#

  1. GPU memory hierarchy — HBM vs SRAM trade-off
  2. Standard attention bottleneck — memory-bound analysis
  3. FlashAttention core insight — IO-aware computation
  4. Online softmax — incremental matematik
  5. Tile-based algorithm — block-by-block computation
  6. Backward pass — recomputation strategy
  7. FlashAttention-1 (Dao 2022) — original paper
  8. FlashAttention-2 — improvements (Dao 2023)
  9. FlashAttention-3 — H100/Hopper optimizations (2024)
  10. Production usage — flash_attn library, PyTorch SDPA

1. GPU Memory Hierarchy#

1.1 NVIDIA H100 specs#

  • HBM (High Bandwidth Memory): 80 GB, 3.35 TB/s bandwidth
  • SRAM (per SM): 228 KB, ~33 TB/s effective
  • Registers: 256 KB per SM, ~100 TB/s
SRAM is roughly 10x faster but orders of magnitude smaller. The engineering problem: put the right data in the right place at the right time.

1.2 Standard attention memory traffic#

For seq=2048, d=128, fp16:
  • Q, K, V each: 2048 × 128 × 2 bytes = 512 KB → HBM
  • QK^T score matrix: 2048 × 2048 × 2 bytes = 8 MB → HBM (large!)
  • softmax weights matrix P: another 8 MB → HBM
  • Output: 512 KB → HBM
Total HBM traffic: ~17 MB read + 17 MB write = 34 MB total per attention layer.
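
The same bookkeeping in code, as a back-of-the-envelope sketch (the exact read/write pattern depends on the kernel schedule, so treat the constants as illustrative):
python
# Rough HBM traffic estimate for standard (materialized) attention, fp16
seq, d, B = 2048, 128, 2          # sequence length, head dim, bytes per element
MB = 1024 ** 2

reads = (3 * seq * d              # read Q, K, V
         + seq * seq              # read S = QK^T back for the softmax
         + seq * seq) * B         # read P = softmax(S) back for P @ V
writes = (seq * seq               # write S
          + seq * seq             # write P
          + seq * d) * B          # write O = P @ V

print(f"reads  ≈ {reads / MB:.1f} MB")    # ≈ 17.5 MB
print(f"writes ≈ {writes / MB:.1f} MB")   # ≈ 16.5 MB
print(f"total  ≈ {(reads + writes) / MB:.1f} MB per attention layer")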

1.3 Math vs memory bottleneck#

Attention compute: ~2 GFLOP per layer (two seq²·d matmuls, QK^T and PV, at 2·seq²·d FLOPs each). H100 BF16 throughput: 989 TFLOPS theoretical. Pure compute time: ~2 GFLOP / 989 TFLOPS ≈ 2 microseconds.
Memory transfer: 34 MB / 3.35 TB/s ≈ 10 microseconds.
Memory takes roughly 5x longer than compute. The GPU is mostly idle, waiting for memory.
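
The same roofline-style arithmetic in code, using the 34 MB traffic figure from section 1.2 (a rough sketch; the peak numbers are the H100 spec-sheet values quoted above):
python
# Roofline-style estimate: compute time vs. memory time on H100
seq, d = 2048, 128
flops = 2 * (2 * seq * seq * d)   # QK^T and P @ V matmuls, 2 FLOPs per multiply-add
hbm_bytes = 34e6                  # HBM traffic estimate from section 1.2

peak_flops = 989e12               # H100 BF16 tensor-core peak
hbm_bw = 3.35e12                  # H100 HBM3 bandwidth, bytes/s

t_compute = flops / peak_flops    # ≈ 2.2 µs
t_memory = hbm_bytes / hbm_bw     # ≈ 10 µs
print(f"compute ≈ {t_compute * 1e6:.1f} µs, memory ≈ {t_memory * 1e6:.1f} µs")
print(f"memory / compute ≈ {t_memory / t_compute:.1f}x → memory-bound")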

1.4 Bandwidth utilization#

Naive attention: ~10% FLOP utilization; the GPU is mostly waiting. FlashAttention: 70%+ FLOP utilization, because HBM traffic drops dramatically.

4-5. Online Softmax + Tile-Based#

4.1 Standard softmax#

softmax(x)_i = exp(x_i) / Σ_j exp(x_j)
The max-subtract trick for numerical stability:
m = max(x)
softmax(x)_i = exp(x_i - m) / Σ_j exp(x_j - m)
Problem: you have to see all of x before emitting any output (one pass to find the max, another to compute the sum).

4.2 Online softmax (Milakov & Gimelshein 2018)#

Incremental computation: build the softmax without seeing all of x at once.
State: (m_running_max, d_running_denom).
For each new element x_new:
m_new = max(m_running_max, x_new)
d_new = d_running_denom × exp(m_running_max - m_new) + exp(x_new - m_new)
m_running_max = m_new
d_running_denom = d_new
With just these running statistics, softmax(x) can be produced without ever holding the whole of x at once.
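
A minimal sketch of this recurrence in PyTorch (my illustration, not code from the paper), checked against the standard two-pass softmax:
python
import torch

def online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Streaming softmax over a 1-D tensor via the (m, d) recurrence."""
    m = torch.tensor(float("-inf"))    # running max
    d = torch.tensor(0.0)              # running denominator
    for x_new in x:
        m_new = torch.maximum(m, x_new)
        d = d * torch.exp(m - m_new) + torch.exp(x_new - m_new)
        m = m_new
    # The running statistics (m, d) are all that is needed to normalize at the end.
    return torch.exp(x - m) / d

x = torch.randn(1000)
assert torch.allclose(online_softmax(x), torch.softmax(x, dim=0), atol=1e-6)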

4.3 Tile-based attention#

FA approach: process the Q, K, V matrices in tiles (blocks).
Tile size: typically 64×64 or 128×128, chosen so a tile fits in SRAM.
For each Q tile Q_i:
  For each K, V tile K_j, V_j:
    Compute the partial QK^T for this tile
    Update the online softmax statistics
    Update the output partial sum
Key: each attention tile is computed in SRAM → the full attention matrix is never stored in HBM.
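
A single-head reference sketch of this loop in plain PyTorch (an illustration, not the fused CUDA kernel; the real kernel also tiles Q and keeps each tile in SRAM). It shows how the online-softmax rescaling lets the output accumulate block by block without the full seq × seq matrix:
python
import torch

def tiled_attention(Q, K, V, tile=128):
    """Tile-by-tile attention with online softmax; Q, K, V have shape [seq, d].
    Mathematically equal to softmax(QK^T / sqrt(d)) @ V, but it never builds
    the full [seq, seq] matrix."""
    seq, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros_like(Q)                     # output accumulator
    m = torch.full((seq, 1), float("-inf"))     # running row max
    l = torch.zeros(seq, 1)                     # running row denominator
    for j in range(0, seq, tile):
        Kj, Vj = K[j:j + tile], V[j:j + tile]
        S = (Q @ Kj.T) * scale                  # [seq, tile] partial scores
        m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
        P = torch.exp(S - m_new)                # tile-local, unnormalized probs
        alpha = torch.exp(m - m_new)            # rescale old stats to the new max
        l = l * alpha + P.sum(dim=-1, keepdim=True)
        O = O * alpha + P @ Vj
        m = m_new
    return O / l                                # one normalization at the very end

Q, K, V = (torch.randn(2048, 128) for _ in range(3))
ref = torch.softmax((Q @ K.T) / 128 ** 0.5, dim=-1) @ V
assert torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-4)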

4.4 Memory savings#

Standard: O(seq²) memory for attention matrix. FlashAttention: O(seq) memory (only tile + running state).
seq=128K context:
  • Standard: 128K² × 2 byte = 32 GB attention matrix per head — IMPOSSIBLE
  • FlashAttention: 128K × 2 byte = 256 KB of running statistics — negligible

4.5 Speed gain#

Less HBM traffic → a smaller memory-bandwidth bottleneck. FA-2 on H100: 70%+ throughput utilization (vs ~10% naive). At long context (>8K), a 2-4x speedup.

4.6 Backward pass: recomputation#

Forward: tile-based, the attention matrix is never stored. The backward pass needs the attention weights — but they were never saved.
Solution: recompute them tile-by-tile during the backward pass, using the per-row softmax statistics (max and denominator) kept from the forward. Cost: ~30% extra FLOPs, but the memory and HBM-traffic savings dominate.
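
FlashAttention does this recomputation inside its own backward kernel; the trade-off itself can be illustrated with PyTorch's generic gradient checkpointing, which likewise avoids saving intermediates in the forward and rebuilds them during backward (a hypothetical illustration, not how the FA kernel is invoked):
python
import torch
from torch.utils.checkpoint import checkpoint

def naive_attention(Q, K, V):
    # Materializes the [seq, seq] probability matrix; without checkpointing,
    # autograd would keep it alive until the backward pass.
    P = torch.softmax((Q @ K.transpose(-2, -1)) / Q.shape[-1] ** 0.5, dim=-1)
    return P @ V

Q, K, V = (torch.randn(4096, 128, requires_grad=True) for _ in range(3))

# Recompute-in-backward: the forward discards intermediates, then re-runs
# during backward to rebuild them: extra FLOPs traded for less memory.
out = checkpoint(naive_attention, Q, K, V, use_reentrant=False)
out.sum().backward()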

4.7 FlashAttention 1 vs 2 vs 3#

  • FA-1 (Dao 2022): original paper, 2-4x speedup, O(seq) memory
  • FA-2 (Dao 2023): better parallelism and work partitioning, ~2x faster than FA-1, 70%+ FLOP utilization
  • FA-3 (2024): H100/Hopper-specific (asynchronous execution, wgmma), 1.5-2x faster than FA-2
2026 production: FA-3 standard on H100, FA-2 on A100.
python
# FlashAttention production usage
# Option 1: PyTorch native (FA-2 backend)
import torch
import torch.nn.functional as F

Q = torch.randn(2, 32, 8192, 128, dtype=torch.bfloat16, device='cuda')
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# Use SDPA — the FlashAttention backend is selected automatically on supported HW
out = F.scaled_dot_product_attention(
    Q, K, V,
    is_causal=True,
    dropout_p=0.0,
)
# ~4x faster than naive attention on H100 at long sequence lengths

# Option 2: flash-attn library (manual)
# pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func

# Shape: [batch, seq, n_heads, d_head] (note: layout differs from SDPA)
q = torch.randn(2, 8192, 32, 128, dtype=torch.bfloat16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)
# FlashAttention-2 kernels; the FA-3 Hopper kernels ship as a separate build

# Benchmarks on H100 80GB
import time

def bench(fn, n=100):
    fn()                                  # warmup: kernel selection, caches
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n):
        _ = fn()
    torch.cuda.synchronize()
    return (time.time() - start) / n * 1000  # ms per call

seq_lens = [512, 2048, 8192, 32768]
for seq in seq_lens:
    Q = torch.randn(1, 32, seq, 128, dtype=torch.bfloat16, device='cuda')
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)
    sdpa_time = bench(lambda: F.scaled_dot_product_attention(Q, K, V, is_causal=True))
    print(f"seq={seq}: SDPA (FlashAttention backend) = {sdpa_time:.2f} ms")
FlashAttention production — PyTorch SDPA + flash-attn library
✅ Lesson 8.3 Summary — FlashAttention
FlashAttention (Dao 2022) moved attention from memory-bound toward compute-bound. It uses the GPU memory hierarchy (slow HBM + fast SRAM) intelligently through tile-based computation, processing tiles incrementally with online-softmax math. Memory: O(seq²) → O(seq). Speed: 2-4x. It is the enabling technology for long context (128K+). Evolution: FA-1 → FA-2 → FA-3. Production: PyTorch SDPA uses the FA-2 backend automatically; on H100, FA-3 reaches 70%+ throughput utilization. In Lesson 8.4 we move on to the KV cache + paged attention: the inference-optimization frontier.

Next Lesson: KV Cache + Paged Attention#

Lesson 8.4: the anatomy of the KV cache in autoregressive inference, prefill vs decode phases, paged attention (vLLM 2023), continuous batching, and production serving optimization.

Frequently Asked Questions

Is FlashAttention's output identical to standard attention? Yes: it is mathematically equivalent, an exact algorithm rather than an approximation; the change is purely a systems-level optimization. In practice the fp16/bf16 outputs can differ by tiny rounding errors because the summation order changes, but nothing beyond normal floating-point noise.
