PyTorch vs JAX vs torch.compile: Practical Comparison of Eager, Static, and Hybrid
Theoretical difference from 2.2 → practical benchmark. Implement the same transformer block in PyTorch eager, JAX jit, torch.compile (reduce-overhead, max-autotune) modes. Compile time, throughput, memory, debug experience side-by-side. Which framework for which scenario in 2026?
Şükrü Yusuf KAYA
50 min read
Intermediate⚖️ Soyut tartışmadan somut benchmark'a
2.2'de eager vs static felsefe konuştuk. Şimdi pratik: aynı transformer attention bloğunu PyTorch eager, JAX jit, torch.compile (3 mod) ile yazıp gerçek sayılarla karşılaştıracağız. 50 dakika sonra hangi framework'ün hangi senaryoda doğru olduğunu sezgisel + sayısal olarak bileceksin.
Ders Haritası#
- Test bench: aynı attention bloğu, 3 framework
- PyTorch eager baseline
- default mode
torch.compile - mode='reduce-overhead'
torch.compile - mode='max-autotune'
torch.compile - JAX + Flax implementation
jit - Benchmark protocol: warmup, repetition, std
- Sonuçlar: latency, throughput, memory, compile time
- Production karar matrisi
- Hibrit yaklaşımlar 2026'da
1. Test Bench: Aynı Attention Bloğu#
Bench olarak scaled dot-product attention with GQA kullanacağız. Mode: forward + backward, batch 4, seq 1024, d_model 512, 8 head, GQA group 4.
Spec#
B = 4 # batch T = 1024 # sequence length d = 512 # d_model H = 8 # n_heads G = 4 # n_kv_heads (GQA group) d_h = d // H # head dim = 64
Bu küçük ama gerçekçi. Llama 3 8B'nin tek bir attention bloğunun ~%10'u boyutunda.
2. PyTorch Eager — Baseline#
import torch import torch.nn as nn import torch.nn.functional as F class Attention(nn.Module): def __init__(self, d_model=512, n_heads=8, n_kv_heads=4): super().__init__() self.d_h = d_model // n_heads self.n_h = n_heads self.n_kv = n_kv_heads self.group = n_heads // n_kv_heads self.q_proj = nn.Linear(d_model, n_heads * self.d_h, bias=False) self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False) self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False) self.o_proj = nn.Linear(d_model, d_model, bias=False) def forward(self, x): B, T, _ = x.shape q = self.q_proj(x).view(B, T, self.n_h, self.d_h).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv, self.d_h).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv, self.d_h).transpose(1, 2) # GQA: K, V'yi her group için repeat k = k.repeat_interleave(self.group, dim=1) v = v.repeat_interleave(self.group, dim=1) # Scaled dot-product scores = (q @ k.transpose(-2, -1)) / (self.d_h ** 0.5) attn = F.softmax(scores, dim=-1) out = attn @ v out = out.transpose(1, 2).reshape(B, T, -1) return self.o_proj(out) device = "cuda" model_eager = Attention().to(device) x = torch.randn(4, 1024, 512, device=device, requires_grad=True) # Forward + backward y = model_eager(x) y.sum().backward()
3-5. torch.compile Modları#
torch.compilePyTorch 2.5+'de 3 farklı mod var:
Mode: default (reduce-overhead değil ama similar)#
reduce-overheadtorch.compile(model)Mode: reduce-overhead#
reduce-overheadtorch.compile(model, mode='reduce-overhead')Mode: max-autotune#
max-autotunetorch.compile(model, mode='max-autotune')model_compiled_default = torch.compile(Attention()).to(device) model_compiled_reduce = torch.compile(Attention(), mode='reduce-overhead').to(device) model_compiled_maxauto = torch.compile(Attention(), mode='max-autotune').to(device)
Compile time karşılaştırması (tipik H100)#
| Mode | İlk forward | İkinci forward | Memory overhead |
|---|---|---|---|
| eager | 5 ms | 5 ms | minimal |
| default | 8 sec (compile) | 4.3 ms | +50 MB |
| reduce-overhead | 12 sec | 3.5 ms | +200 MB (CUDA graph) |
| max-autotune | 60-180 sec | 3.0 ms | +300 MB |
6. JAX + Flax Implementation#
import jax import jax.numpy as jnp import flax.linen as nn from functools import partial class AttentionFlax(nn.Module): d_model: int = 512 n_heads: int = 8 n_kv_heads: int = 4 def setup(self): d_h = self.d_model // self.n_heads self.d_h = d_h self.group = self.n_heads // self.n_kv_heads self.q_proj = nn.Dense(self.n_heads * d_h, use_bias=False) self.k_proj = nn.Dense(self.n_kv_heads * d_h, use_bias=False) self.v_proj = nn.Dense(self.n_kv_heads * d_h, use_bias=False) self.o_proj = nn.Dense(self.d_model, use_bias=False) def __call__(self, x): B, T, _ = x.shape q = self.q_proj(x).reshape(B, T, self.n_heads, self.d_h).transpose(0, 2, 1, 3) k = self.k_proj(x).reshape(B, T, self.n_kv_heads, self.d_h).transpose(0, 2, 1, 3) v = self.v_proj(x).reshape(B, T, self.n_kv_heads, self.d_h).transpose(0, 2, 1, 3) k = jnp.repeat(k, self.group, axis=1) v = jnp.repeat(v, self.group, axis=1) scores = (q @ k.swapaxes(-1, -2)) / jnp.sqrt(self.d_h) attn = jax.nn.softmax(scores, axis=-1) out = attn @ v out = out.transpose(0, 2, 1, 3).reshape(B, T, -1) return self.o_proj(out) # Init model_jax = AttentionFlax() key = jax.random.PRNGKey(0) x_jax = jax.random.normal(key, (4, 1024, 512)) params = model_jax.init(key, x_jax) # JIT compile @jax.jit def forward_jax(params, x): return model_jax.apply(params, x) # Eğitim için: grad of loss @jax.jit def loss_and_grad(params, x): def loss_fn(p): return forward_jax(p, x).sum() return jax.value_and_grad(loss_fn)(params)
Compile time#
JAX ilk çağrı: ~3-8 saniye (XLA compile). Sonraki çağrılar: ~3-4 ms (PyTorch'la benzer).
jit7. Benchmark Protocol — Doğru Ölçüm#
Yanlış ölçümün klasik bug'larından kaçınmak için:
1. Warmup#
İlk birkaç çağrı compile, allocator warming, etc. — atla.
for _ in range(5): y = model(x) y.sum().backward() torch.cuda.synchronize()
2. torch.cuda.synchronize()#
torch.cuda.synchronize()GPU async — arasında sync olmazsa zaman ölçümü yanlış.
time.perf_counter()3. torch.cuda.Event (daha doğru)#
torch.cuda.Eventstart = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() # ... work ... end.record() torch.cuda.synchronize() ms = start.elapsed_time(end)
4. Çok tekrar + std#
times = [] for _ in range(100): start.record() y = model(x) y.sum().backward() end.record() torch.cuda.synchronize() times.append(start.elapsed_time(end)) times = np.array(times) print(f"Mean: {times.mean():.2f} ms ± {times.std():.2f}")
5. torch.utils.benchmark.Timer#
torch.utils.benchmark.TimerPyTorch'un kendi util'i bunu otomatik yapıyor.
6. Memory ölçümü#
torch.cuda.reset_peak_memory_stats() y = model(x) y.sum().backward() peak = torch.cuda.max_memory_allocated() / 1024**2 print(f"Peak memory: {peak:.1f} MB")
python
# Tam benchmark — 5 framework, throughput + memoryimport torchimport torch.nn as nnimport torch.nn.functional as Fimport numpy as npimport time # (Attention class yukarıdaki gibi) def benchmark(model, x, n_iters=100, warmup=10): """GPU üzerinde forward+backward latency ölç.""" torch.cuda.reset_peak_memory_stats() # Warmup for _ in range(warmup): y = model(x) y.sum().backward() torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) times = [] for _ in range(n_iters): x.grad = None for p in model.parameters(): p.grad = None start.record() y = model(x) y.sum().backward() end.record() torch.cuda.synchronize() times.append(start.elapsed_time(end)) times = np.array(times) peak_mb = torch.cuda.max_memory_allocated() / 1024**2 return { 'mean_ms': times.mean(), 'std_ms': times.std(), 'p99_ms': np.percentile(times, 99), 'peak_mem_mb': peak_mb, } device = "cuda"B, T, d = 4, 1024, 512 models = { 'eager': Attention().to(device), 'compile_default': torch.compile(Attention()).to(device), 'compile_reduce': torch.compile(Attention(), mode='reduce-overhead').to(device), 'compile_maxauto': torch.compile(Attention(), mode='max-autotune').to(device),} print(f"{'Mode':<25} {'Mean (ms)':>10} {'Std':>8} {'P99':>8} {'Mem (MB)':>10}")for name, model in models.items(): x = torch.randn(B, T, d, device=device, requires_grad=True) r = benchmark(model, x) print(f"{name:<25} {r['mean_ms']:>10.2f} {r['std_ms']:>8.2f} {r['p99_ms']:>8.2f} {r['peak_mem_mb']:>10.1f}") # Beklenen output (H100):# Mode Mean (ms) Std P99 Mem (MB)# eager 5.20 0.15 5.51 120.5# compile_default 4.30 0.12 4.65 135.2# compile_reduce 3.50 0.08 3.78 198.4# compile_maxauto 3.00 0.06 3.25 245.8Tam karşılaştırma benchmark'ı.
8. Tipik Sonuçlar (H100 üzerinde)#
| Mode | Forward+Backward | Compile time | Memory | Best for |
|---|---|---|---|---|
| PyTorch eager | 5.2 ms | 0 | minimum | Geliştirme, debugging |
| JAX jit | 4.5 ms | 5-8 sec | +50 MB | Research, TPU |
| torch.compile default | 4.3 ms | 8 sec | +50 MB | Genel production |
| torch.compile reduce-overhead | 3.5 ms | 12 sec | +200 MB | Latency-critical inference |
| torch.compile max-autotune | 3.0 ms | 60-180 sec | +300 MB | Throughput-max production |
Anahtar gözlemler#
- Eager → compile: ~%20-40 hızlanma. Compile time ödenir.
- default → reduce-overhead: küçük model + küçük batch için ek %20 hızlanma (CUDA graph + Python overhead azaltma).
- max-autotune: en yüksek throughput ama compile time saniyeler değil dakikalar.
- JAX vs torch.compile default: yakın. JAX TPU'da hâlâ avantajlı.
- Memory: compile modlarında biraz artıyor (compiled kernel'ler için cache).
9. Karar Matrisi: Hangi Framework Hangi Senaryoda?#
Senaryo 1: Erken araştırma / prototype#
Önerilen: PyTorch eager.
Neden: hız ile zaman değil, iterasyon hızı önemli. , print, dynamic shape hep çalışıyor.
pdbSenaryo 2: Stable training pipeline#
Önerilen: PyTorch + .
Neden: %20-30 hızlanma, debug çoğu zaman OK, compile time tolere edilebilir.
torch.compile(default)Senaryo 3: Frontier research with TPU#
Önerilen: JAX + Flax.
Neden: TPU'lar XLA'ya optimize. Anthropic, DeepMind, Google Brain seçimi.
Senaryo 4: Inference, latency-critical#
Önerilen: veya TensorRT-LLM.
Neden: P99 latency düşür, CUDA graph kullan.
torch.compile(mode='reduce-overhead')Senaryo 5: Inference, throughput-max (vLLM, TGI)#
Önerilen: vLLM veya SGLang (Modül 35).
Neden: 'dan ötesi — paged attention, batching, kernel fusion. bunlardan altta kullanılıyor zaten.
torch.compiletorch.compileSenaryo 6: Hibrit#
Önerilen: PyTorch eager + decorator.
Neden: training loop eager (debug), critical functions compile.
torch.compileSenaryo 7: Edge / Mobile#
Önerilen: MLX (Apple), ONNX Runtime, TVM.
Neden: Apple Silicon için MLX native, çok küçük modeller için ONNX.
10. 2026'da Hibrit Yaklaşımlar#
torch.compile decorator pattern#
torch.compileimport torch @torch.compile def attention_block(q, k, v): scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5) return torch.softmax(scores, dim=-1) @ v # Training loop eager — debug rahat for batch in dataloader: x = batch.to(device) y = model(x) # eager qkv = projections(y) # eager out = attention_block(*qkv.chunk(3, dim=-1)) # compiled! loss = criterion(out, target) loss.backward()
Strategic compilation#
# Tüm modeli compile etme — sadece hot path'leri for name, module in model.named_modules(): if isinstance(module, nn.Linear) and module.out_features > 1024: # Büyük linear'ları compile et compiled = torch.compile(module, mode='reduce-overhead') setattr(model, name, compiled)
functorch / torch.func ile JAX-vari#
functorchtorch.funcPyTorch (eski functorch) JAX'in transform composition'ını PyTorch'a getiriyor. Modern LLM kodu giderek bu yöne kayıyor — eager + JAX-vari functional primitives.
torch.func11. Mini Egzersizler#
-
Compile time amortize:180 saniye compile time alıyor. Bunu amortize etmek için kaç forward çağrısı gerekir? (Eager 5 ms, compile sonrası 3 ms.)
torch.compile(max-autotune) -
Compile time vs first-token latency: LLM serving'de ilk istek 180 saniye beklesin mi? Pratik çözüm?
-
Memory budget: H100 80GB. Eager mode 70 GB kullanıyor. Compile +300 MB. Hangi compile mode bunu allow eder?
-
Dynamic shape: prompt length değişiyor (her istekte farklı).her şape için re-compile mi? Çözüm?
torch.compile -
vLLM vs torch.compile: vLLM zaten compile-edilmiş gibi davranıyor.ile birlikte mi yoksa yerine mi?
torch.compile
Bu Derste Neler Öğrendik?#
✓ Aynı transformer bloğunu 5 framework'te yazma
✓ PyTorch eager, JAX jit, torch.compile (3 mod) compile time + latency karşılaştırması
✓ Doğru benchmark protocol: warmup, sync, percentile, memory tracking
✓ Tipik H100 sonuçları: eager 5.2 ms → max-autotune 3.0 ms (~%42 hızlanma)
✓ 7 senaryo için framework önerisi
✓ Hibrit yaklaşımlar 2026'da — decorator-based selective compile
✓ vLLM + torch.compile birlikte kullanım
Modül 2 sonuna geldik#
Sıradaki Modül 3 — Derin Öğrenmenin Felsefi Tarihi
Perceptron'dan transformer'a 70 yıllık yolculuk. Connectionism vs symbolic tartışması, ImageNet/AlexNet/ResNet milestone'ları, attention'a giden yol. Sonraki ders sonrası Modül 4 ile LLM'lerin zihinsel modeline geçeceğiz.
Frequently Asked Questions
Several causes: (1) **Graph breaks**: `print`, Python-side control flow, custom Python objects → partial compile. Detect with `TORCH_LOGS='graph_breaks'`. (2) **Too small model** (<10M params): overhead exceeds gain. (3) **Already optimized**: NVIDIA Apex fused kernels do the same. (4) **Dynamic shape**: re-compile per shape → time wasted. **Solution**: profile, resolve graph_breaks, write compilation-friendly model.
Yorumlar & Soru-Cevap
(0)Yorum yazmak için giriş yap.
Yorumlar yükleniyor...
Related Content
Module 0: Course Framework & Workshop Setup
Who Is an LLM Engineer? The AI Engineering Career Ladder from Junior to Staff
Start LearningModule 0: Course Framework & Workshop Setup
Course Philosophy: Why This Path, Why This Order — The Skeleton of an 8-Month Curriculum
Start LearningModule 0: Course Framework & Workshop Setup