Skip to content

PyTorch vs JAX vs torch.compile: Practical Comparison of Eager, Static, and Hybrid

Theoretical difference from 2.2 → practical benchmark. Implement the same transformer block in PyTorch eager, JAX jit, torch.compile (reduce-overhead, max-autotune) modes. Compile time, throughput, memory, debug experience side-by-side. Which framework for which scenario in 2026?

Şükrü Yusuf KAYA
50 min read
Intermediate
PyTorch vs JAX vs torch.compile: Eager, Static ve Hybrid'in Pratik Karşılaştırması
⚖️ Soyut tartışmadan somut benchmark'a
2.2'de eager vs static felsefe konuştuk. Şimdi pratik: aynı transformer attention bloğunu PyTorch eager, JAX jit, torch.compile (3 mod) ile yazıp gerçek sayılarla karşılaştıracağız. 50 dakika sonra hangi framework'ün hangi senaryoda doğru olduğunu sezgisel + sayısal olarak bileceksin.

Ders Haritası#

  1. Test bench: aynı attention bloğu, 3 framework
  2. PyTorch eager baseline
  3. torch.compile
    default mode
  4. torch.compile
    mode='reduce-overhead'
  5. torch.compile
    mode='max-autotune'
  6. JAX
    jit
    + Flax
    implementation
  7. Benchmark protocol: warmup, repetition, std
  8. Sonuçlar: latency, throughput, memory, compile time
  9. Production karar matrisi
  10. Hibrit yaklaşımlar 2026'da

1. Test Bench: Aynı Attention Bloğu#

Bench olarak scaled dot-product attention with GQA kullanacağız. Mode: forward + backward, batch 4, seq 1024, d_model 512, 8 head, GQA group 4.

Spec#

B = 4 # batch T = 1024 # sequence length d = 512 # d_model H = 8 # n_heads G = 4 # n_kv_heads (GQA group) d_h = d // H # head dim = 64
Bu küçük ama gerçekçi. Llama 3 8B'nin tek bir attention bloğunun ~%10'u boyutunda.

2. PyTorch Eager — Baseline#

import torch import torch.nn as nn import torch.nn.functional as F class Attention(nn.Module): def __init__(self, d_model=512, n_heads=8, n_kv_heads=4): super().__init__() self.d_h = d_model // n_heads self.n_h = n_heads self.n_kv = n_kv_heads self.group = n_heads // n_kv_heads self.q_proj = nn.Linear(d_model, n_heads * self.d_h, bias=False) self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False) self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False) self.o_proj = nn.Linear(d_model, d_model, bias=False) def forward(self, x): B, T, _ = x.shape q = self.q_proj(x).view(B, T, self.n_h, self.d_h).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv, self.d_h).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv, self.d_h).transpose(1, 2) # GQA: K, V'yi her group için repeat k = k.repeat_interleave(self.group, dim=1) v = v.repeat_interleave(self.group, dim=1) # Scaled dot-product scores = (q @ k.transpose(-2, -1)) / (self.d_h ** 0.5) attn = F.softmax(scores, dim=-1) out = attn @ v out = out.transpose(1, 2).reshape(B, T, -1) return self.o_proj(out) device = "cuda" model_eager = Attention().to(device) x = torch.randn(4, 1024, 512, device=device, requires_grad=True) # Forward + backward y = model_eager(x) y.sum().backward()

3-5.
torch.compile
Modları#

PyTorch 2.5+'de 3 farklı mod var:

Mode: default (
reduce-overhead
değil ama similar)#

torch.compile(model)
— minimum optimization, hızlı compile.

Mode:
reduce-overhead
#

torch.compile(model, mode='reduce-overhead')
— CUDA Graphs kullanarak Python overhead'i azaltır. Küçük batch / latency-critical inference için ideal.

Mode:
max-autotune
#

torch.compile(model, mode='max-autotune')
— Triton ile birden çok kernel autotune. En yüksek throughput, en uzun compile time.
model_compiled_default = torch.compile(Attention()).to(device) model_compiled_reduce = torch.compile(Attention(), mode='reduce-overhead').to(device) model_compiled_maxauto = torch.compile(Attention(), mode='max-autotune').to(device)

Compile time karşılaştırması (tipik H100)#

Modeİlk forwardİkinci forwardMemory overhead
eager5 ms5 msminimal
default8 sec (compile)4.3 ms+50 MB
reduce-overhead12 sec3.5 ms+200 MB (CUDA graph)
max-autotune60-180 sec3.0 ms+300 MB

6. JAX + Flax Implementation#

import jax import jax.numpy as jnp import flax.linen as nn from functools import partial class AttentionFlax(nn.Module): d_model: int = 512 n_heads: int = 8 n_kv_heads: int = 4 def setup(self): d_h = self.d_model // self.n_heads self.d_h = d_h self.group = self.n_heads // self.n_kv_heads self.q_proj = nn.Dense(self.n_heads * d_h, use_bias=False) self.k_proj = nn.Dense(self.n_kv_heads * d_h, use_bias=False) self.v_proj = nn.Dense(self.n_kv_heads * d_h, use_bias=False) self.o_proj = nn.Dense(self.d_model, use_bias=False) def __call__(self, x): B, T, _ = x.shape q = self.q_proj(x).reshape(B, T, self.n_heads, self.d_h).transpose(0, 2, 1, 3) k = self.k_proj(x).reshape(B, T, self.n_kv_heads, self.d_h).transpose(0, 2, 1, 3) v = self.v_proj(x).reshape(B, T, self.n_kv_heads, self.d_h).transpose(0, 2, 1, 3) k = jnp.repeat(k, self.group, axis=1) v = jnp.repeat(v, self.group, axis=1) scores = (q @ k.swapaxes(-1, -2)) / jnp.sqrt(self.d_h) attn = jax.nn.softmax(scores, axis=-1) out = attn @ v out = out.transpose(0, 2, 1, 3).reshape(B, T, -1) return self.o_proj(out) # Init model_jax = AttentionFlax() key = jax.random.PRNGKey(0) x_jax = jax.random.normal(key, (4, 1024, 512)) params = model_jax.init(key, x_jax) # JIT compile @jax.jit def forward_jax(params, x): return model_jax.apply(params, x) # Eğitim için: grad of loss @jax.jit def loss_and_grad(params, x): def loss_fn(p): return forward_jax(p, x).sum() return jax.value_and_grad(loss_fn)(params)

Compile time#

JAX
jit
ilk çağrı: ~3-8 saniye (XLA compile). Sonraki çağrılar: ~3-4 ms (PyTorch'la benzer).

7. Benchmark Protocol — Doğru Ölçüm#

Yanlış ölçümün klasik bug'larından kaçınmak için:

1. Warmup#

İlk birkaç çağrı compile, allocator warming, etc. — atla.
for _ in range(5): y = model(x) y.sum().backward() torch.cuda.synchronize()

2.
torch.cuda.synchronize()
#

GPU async —
time.perf_counter()
arasında sync olmazsa zaman ölçümü yanlış.

3.
torch.cuda.Event
(daha doğru)#

start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() # ... work ... end.record() torch.cuda.synchronize() ms = start.elapsed_time(end)

4. Çok tekrar + std#

times = [] for _ in range(100): start.record() y = model(x) y.sum().backward() end.record() torch.cuda.synchronize() times.append(start.elapsed_time(end)) times = np.array(times) print(f"Mean: {times.mean():.2f} ms ± {times.std():.2f}")

5.
torch.utils.benchmark.Timer
#

PyTorch'un kendi util'i bunu otomatik yapıyor.

6. Memory ölçümü#

torch.cuda.reset_peak_memory_stats() y = model(x) y.sum().backward() peak = torch.cuda.max_memory_allocated() / 1024**2 print(f"Peak memory: {peak:.1f} MB")
python
# Tam benchmark — 5 framework, throughput + memory
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
 
# (Attention class yukarıdaki gibi)
 
def benchmark(model, x, n_iters=100, warmup=10):
"""GPU üzerinde forward+backward latency ölç."""
torch.cuda.reset_peak_memory_stats()
# Warmup
for _ in range(warmup):
y = model(x)
y.sum().backward()
torch.cuda.synchronize()
 
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
for _ in range(n_iters):
x.grad = None
for p in model.parameters():
p.grad = None
start.record()
y = model(x)
y.sum().backward()
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))
 
times = np.array(times)
peak_mb = torch.cuda.max_memory_allocated() / 1024**2
return {
'mean_ms': times.mean(),
'std_ms': times.std(),
'p99_ms': np.percentile(times, 99),
'peak_mem_mb': peak_mb,
}
 
device = "cuda"
B, T, d = 4, 1024, 512
 
models = {
'eager': Attention().to(device),
'compile_default': torch.compile(Attention()).to(device),
'compile_reduce': torch.compile(Attention(), mode='reduce-overhead').to(device),
'compile_maxauto': torch.compile(Attention(), mode='max-autotune').to(device),
}
 
print(f"{'Mode':<25} {'Mean (ms)':>10} {'Std':>8} {'P99':>8} {'Mem (MB)':>10}")
for name, model in models.items():
x = torch.randn(B, T, d, device=device, requires_grad=True)
r = benchmark(model, x)
print(f"{name:<25} {r['mean_ms']:>10.2f} {r['std_ms']:>8.2f} {r['p99_ms']:>8.2f} {r['peak_mem_mb']:>10.1f}")
 
# Beklenen output (H100):
# Mode Mean (ms) Std P99 Mem (MB)
# eager 5.20 0.15 5.51 120.5
# compile_default 4.30 0.12 4.65 135.2
# compile_reduce 3.50 0.08 3.78 198.4
# compile_maxauto 3.00 0.06 3.25 245.8
Tam karşılaştırma benchmark'ı.

8. Tipik Sonuçlar (H100 üzerinde)#

ModeForward+BackwardCompile timeMemoryBest for
PyTorch eager5.2 ms0minimumGeliştirme, debugging
JAX jit4.5 ms5-8 sec+50 MBResearch, TPU
torch.compile default4.3 ms8 sec+50 MBGenel production
torch.compile reduce-overhead3.5 ms12 sec+200 MBLatency-critical inference
torch.compile max-autotune3.0 ms60-180 sec+300 MBThroughput-max production

Anahtar gözlemler#

  1. Eager → compile: ~%20-40 hızlanma. Compile time ödenir.
  2. default → reduce-overhead: küçük model + küçük batch için ek %20 hızlanma (CUDA graph + Python overhead azaltma).
  3. max-autotune: en yüksek throughput ama compile time saniyeler değil dakikalar.
  4. JAX vs torch.compile default: yakın. JAX TPU'da hâlâ avantajlı.
  5. Memory: compile modlarında biraz artıyor (compiled kernel'ler için cache).

9. Karar Matrisi: Hangi Framework Hangi Senaryoda?#

Senaryo 1: Erken araştırma / prototype#

Önerilen: PyTorch eager. Neden: hız ile zaman değil, iterasyon hızı önemli.
pdb
, print, dynamic shape hep çalışıyor.

Senaryo 2: Stable training pipeline#

Önerilen: PyTorch +
torch.compile(default)
. Neden: %20-30 hızlanma, debug çoğu zaman OK, compile time tolere edilebilir.

Senaryo 3: Frontier research with TPU#

Önerilen: JAX + Flax. Neden: TPU'lar XLA'ya optimize. Anthropic, DeepMind, Google Brain seçimi.

Senaryo 4: Inference, latency-critical#

Önerilen:
torch.compile(mode='reduce-overhead')
veya TensorRT-LLM. Neden: P99 latency düşür, CUDA graph kullan.

Senaryo 5: Inference, throughput-max (vLLM, TGI)#

Önerilen: vLLM veya SGLang (Modül 35). Neden:
torch.compile
'dan ötesi — paged attention, batching, kernel fusion.
torch.compile
bunlardan altta kullanılıyor zaten.

Senaryo 6: Hibrit#

Önerilen: PyTorch eager +
torch.compile
decorator
. Neden: training loop eager (debug), critical functions compile.

Senaryo 7: Edge / Mobile#

Önerilen: MLX (Apple), ONNX Runtime, TVM. Neden: Apple Silicon için MLX native, çok küçük modeller için ONNX.

10. 2026'da Hibrit Yaklaşımlar#

torch.compile
decorator pattern#

import torch @torch.compile def attention_block(q, k, v): scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5) return torch.softmax(scores, dim=-1) @ v # Training loop eager — debug rahat for batch in dataloader: x = batch.to(device) y = model(x) # eager qkv = projections(y) # eager out = attention_block(*qkv.chunk(3, dim=-1)) # compiled! loss = criterion(out, target) loss.backward()

Strategic compilation#

# Tüm modeli compile etme — sadece hot path'leri for name, module in model.named_modules(): if isinstance(module, nn.Linear) and module.out_features > 1024: # Büyük linear'ları compile et compiled = torch.compile(module, mode='reduce-overhead') setattr(model, name, compiled)

functorch
/
torch.func
ile JAX-vari#

PyTorch
torch.func
(eski functorch) JAX'in transform composition'ını PyTorch'a getiriyor. Modern LLM kodu giderek bu yöne kayıyor — eager + JAX-vari functional primitives.

11. Mini Egzersizler#

  1. Compile time amortize:
    torch.compile(max-autotune)
    180 saniye compile time alıyor. Bunu amortize etmek için kaç forward çağrısı gerekir? (Eager 5 ms, compile sonrası 3 ms.)
  2. Compile time vs first-token latency: LLM serving'de ilk istek 180 saniye beklesin mi? Pratik çözüm?
  3. Memory budget: H100 80GB. Eager mode 70 GB kullanıyor. Compile +300 MB. Hangi compile mode bunu allow eder?
  4. Dynamic shape: prompt length değişiyor (her istekte farklı).
    torch.compile
    her şape için re-compile mi? Çözüm?
  5. vLLM vs torch.compile: vLLM zaten compile-edilmiş gibi davranıyor.
    torch.compile
    ile birlikte mi yoksa yerine mi?

Bu Derste Neler Öğrendik?#

✓ Aynı transformer bloğunu 5 framework'te yazma ✓ PyTorch eager, JAX jit, torch.compile (3 mod) compile time + latency karşılaştırması ✓ Doğru benchmark protocol: warmup, sync, percentile, memory tracking ✓ Tipik H100 sonuçları: eager 5.2 ms → max-autotune 3.0 ms (~%42 hızlanma) ✓ 7 senaryo için framework önerisiHibrit yaklaşımlar 2026'da — decorator-based selective compile ✓ vLLM + torch.compile birlikte kullanım

Modül 2 sonuna geldik#

Sıradaki Modül 3 — Derin Öğrenmenin Felsefi Tarihi Perceptron'dan transformer'a 70 yıllık yolculuk. Connectionism vs symbolic tartışması, ImageNet/AlexNet/ResNet milestone'ları, attention'a giden yol. Sonraki ders sonrası Modül 4 ile LLM'lerin zihinsel modeline geçeceğiz.

Frequently Asked Questions

Several causes: (1) **Graph breaks**: `print`, Python-side control flow, custom Python objects → partial compile. Detect with `TORCH_LOGS='graph_breaks'`. (2) **Too small model** (<10M params): overhead exceeds gain. (3) **Already optimized**: NVIDIA Apex fused kernels do the same. (4) **Dynamic shape**: re-compile per shape → time wasted. **Solution**: profile, resolve graph_breaks, write compilation-friendly model.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

Related Content