
FlashAttention v2/v3 Internals: Tile + Online Softmax + Hopper WGMMA

FlashAttention's mathematical heart: tile-by-tile attention compute, **online softmax** (incremental running max + sum), and the backward-pass recomputation strategy. The v2 → v3 differences: Hopper WGMMA, async memory, FP8 attention. Plus: the head-size constraint, deterministic mode, and the varlen variant.

Şükrü Yusuf KAYA
36 min read
Advanced

1. Online Softmax Math#

Standard softmax:

softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))

Both max(x) and sum(...) need the entire vector before a single output can be produced.
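A minimal NumPy reference for this max-shifted form (the tile-by-tile version in the next step has to reproduce it exactly):

```python
import numpy as np

def softmax(x):
    # Subtract the max before exponentiating so exp() never overflows;
    # mathematically identical to exp(x) / sum(exp(x)).
    z = np.exp(x - np.max(x))
    return z / np.sum(z)

x = np.array([1000.0, 1001.0, 1002.0])
print(softmax(x))  # [0.09003057 0.24472847 0.66524096] -- no overflow
```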
Online softmax (tile-by-tile):
Init: m = -inf, l = 0, O = 0
For each tile B_i:
    m_new      = max(m, max(B_i))
    scale_prev = exp(m - m_new)
    scale_curr = exp(max(B_i) - m_new)
    l = l × scale_prev + scale_curr × sum(exp(B_i - max(B_i)))
    O = O × scale_prev + scale_curr × (exp(B_i - max(B_i)) × V_i)
    m = m_new
Final: O / l
Intuition: for each tile, update a running max and a running sum. Softmax and the matmul with V happen in the same pass, without ever materializing the full attention matrix. A from-scratch NumPy version is sketched below.
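A from-scratch NumPy sketch of the loop above for a single query row (the multi-query, multi-head case is the same update applied per row; the 1/sqrt(d) scale is omitted for brevity), checked against a fully materialized softmax:

```python
import numpy as np

def online_softmax_attention(q, K, V, tile_size=4):
    """softmax(q @ K.T) @ V for one query row, computed tile by tile."""
    d = V.shape[1]
    m = -np.inf          # running max
    l = 0.0              # running sum of exp
    O = np.zeros(d)      # running (unnormalized) output

    for start in range(0, K.shape[0], tile_size):
        k_tile = K[start:start + tile_size]
        v_tile = V[start:start + tile_size]
        s = q @ k_tile.T                   # scores for this tile

        m_new = max(m, s.max())
        scale_prev = np.exp(m - m_new)     # rescale old stats to the new max
        p = np.exp(s - m_new)              # = scale_curr * exp(s - max(s))

        l = l * scale_prev + p.sum()
        O = O * scale_prev + p @ v_tile
        m = m_new

    return O / l                           # normalize once at the end

rng = np.random.default_rng(0)
q, K, V = rng.normal(size=8), rng.normal(size=(16, 8)), rng.normal(size=(16, 8))

s_full = q @ K.T                           # reference: full materialization
p_full = np.exp(s_full - s_full.max())
ref = (p_full / p_full.sum()) @ V
print(np.allclose(online_softmax_attention(q, K, V), ref))  # True
```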

2. FlashAttention v2 vs v3#

| Aspect | v2 (2023) | v3 (2024, Hopper-only) |
|---|---|---|
| Architecture | Ada/Ampere + Hopper | Hopper-only (H100, H200) |
| Tile schedule | static | producer-consumer asynchronous |
| Matmul | WMMA-style mma.sync | WGMMA (warp-group MMA, 128 threads coordinated) |
| Memory load | sync | async via TMA (Tensor Memory Accelerator) |
| Precision | FP16/BF16 | + FP8 (e4m3 + e5m2) |
| Throughput (H100) | 540 TFLOPS | 740 TFLOPS (FP8: 1200) |
| Speedup vs naïve | 4-8× | 6-12× |
For an RTX 4090 (Ada): use v2.7.x; v3 is H100-optimized.
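A minimal usage sketch, assuming the flash-attn Python package and its flash_attn_func entry point, with inputs laid out as (batch, seqlen, nheads, headdim) in fp16/bf16 on a CUDA device:

```python
# pip install flash-attn   (v2.7.x wheels cover Ada/Ampere; v3 targets Hopper)
import torch
from flash_attn import flash_attn_func

# Assumed layout: (batch, seqlen, nheads, headdim), fp16/bf16, on the GPU.
q = torch.randn(1, 2048, 16, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 2048, 16, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 2048, 16, 64, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)  # fused attention, O(N) extra memory
print(out.shape, torch.cuda.get_device_capability())  # sm 9.x == Hopper
```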
✅ Deliverables
  1. Read the FlashAttention v2 and v3 papers (especially the algorithm sections).
  2. Implement online softmax from scratch in NumPy.
  3. Next lesson is 13.2 (Triton Crash Course).
