torch.compile and torch.fx: Graph Capture, JIT Compilation, and Production Optimization
An in-depth look at torch.compile, the game-changing PyTorch 2.0+ feature: the TorchDynamo + TorchInductor + Triton pipeline, FX graph manipulation, compile modes (default/reduce-overhead/max-autotune), graph-break debugging, dynamic shapes, and production trade-offs. A production-focused extension of Module 2.5.
Şükrü Yusuf KAYA
60 min read
Advanced ⚡ PyTorch's "miracle layer"
In Module 2.5 we compared PyTorch eager mode against JAX static graphs. torch.compile (PyTorch 2.0, March 2023) combines the advantages of both: static-graph performance, achieved by compiling eager Python behind the scenes. Sixty minutes from now you will know how TorchDynamo captures Python bytecode, why TorchInductor compiles down to Triton, and how to debug graph breaks; you will own torch.compile in production.
Lesson Map (Detailed)#
- The 3 layers of torch.compile: Dynamo, Inductor, Triton
- TorchDynamo: Python bytecode capture
- TorchInductor: graph → optimized kernels
- Triton backend: GPU code generation
- torch.fx: symbolic graph manipulation
- Compile modes: default vs reduce-overhead vs max-autotune
- Graph breaks: detection, debugging, fixes
- Dynamic shapes: 2.1+ improvements
- Production deployment patterns
- torch.compile in LLM inference
- Compile time vs runtime trade-off
- Interaction with vLLM and TGI
1. The 3 Layers of torch.compile#
torch.compile is not a single thing; it is a stack:
```
Python code
    ↓
[1] TorchDynamo — Python bytecode → FX graph
    ↓
[2] TorchInductor — FX graph → optimized lower-level IR
    ↓
[3] Triton (default) — IR → GPU kernels
    ↓
GPU
```
Why layered?#
Each layer is one abstraction level:
- Dynamo: "what is the code doing?" (graph capture)
- Inductor: "how do I optimize it?" (fusion, scheduling)
- Triton: "how do I run it on the GPU?" (kernel codegen)
Alternative backends#
Inductor is the default backend, but others exist:
- nvFuser (NVIDIA): NVIDIA-specific optimization
- OpenXLA: Google XLA-based
- AOTInductor: ahead-of-time compilation (production ready)
- TensorRT: NVIDIA TensorRT integration (2024+)
Backend flag#
```python
@torch.compile(backend="inductor")    # default
@torch.compile(backend="aot_eager")   # debugging
@torch.compile(backend="cudagraphs")  # CUDA graphs only
```
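To see which backends your install actually registers, there is a quick check; a minimal sketch (the exact output varies by PyTorch version and installed extensions):

```python
import torch._dynamo as dynamo

# Names of the registered compile backends, e.g. ['cudagraphs', 'inductor', ...]
print(dynamo.list_backends())
```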
2. TorchDynamo — Python Bytecode Capture#
Problem: how do you turn Python code into a static graph?
Earlier attempts (TorchScript, torch.jit.trace) were only partially successful:
- TorchScript: a Python subset (hard to write up front)
- torch.jit.trace: traces with a single input (control flow is lost)
The Dynamo approach#
Use CPython's frame evaluation hook. Intercept Python bytecode at runtime and translate it into a graph.
Flow#
1. A Python function is called (e.g. model.forward())
2. Dynamo inspects the bytecode
3. Tensor operations are recorded into an FX graph
4. Non-tensor Python (control flow, prints) creates a "graph break"
5. The graphs are compiled; the remaining Python stays eager
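You can watch this capture happen with a minimal custom backend; a sketch using the documented backend signature (a callable that receives the captured FX GraphModule and example inputs, and returns a callable):

```python
import torch

def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo hands us the captured FX graph; print it, then run it unoptimized
    print(gm.graph)
    return gm.forward

@torch.compile(backend=print_graph_backend)
def f(x):
    return torch.relu(x) * 2

f(torch.randn(4))  # first call triggers capture and prints the graph
```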
Graph breaks#
Dynamo cannot put everything into the graph:
- print(), logging
- Python data structure modifications
- Custom Python class methods (in some cases)
- NumPy operations
- Custom autograd Functions (partially)
When that happens, a graph break occurs and the function is split in two: the part before the break is compiled, the rest runs eager.
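A minimal illustration; the function still runs correctly, it just executes as two compiled fragments around the eager print:

```python
import torch

@torch.compile
def f(x):
    y = x * 2
    print("intermediate sum:", y.sum().item())  # side effect + .item() → graph break
    return y.relu()

f(torch.randn(8))
```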
Frame evaluation API#
Dynamo works best on Python 3.11+ (the frame eval API hooks were improved there). It works on 3.10, but performance is lower.
3. TorchInductor — FX Graph → Optimized IR#
Dynamo produces an FX graph. Now it needs to be optimized.
What Inductor does#
- Lowering: high-level ops (aten.matmul) → low-level primitives
- Fusion: merge consecutive ops (mul + add → fused multiply-add)
- Memory planning: tensor lifetimes, optimal allocation
- Scheduling: parallel execution across GPU streams
- Code generation: kernel code for Triton (GPU) or C++ (CPU)
Operator fusion example#
y = x.relu() + x.relu() * 2
Naive: 3 separate kernels (relu, mul, add) → 3 memory round-trips.
Inductor fusion: 1 kernel with in-register computation → ~3x speedup.
Inductor IR#
An internal representation, lower-level than FX. Loop nests and memory accesses are explicit.
Code generation#
For GPUs, it generates Triton kernels:
```python
import triton
import triton.language as tl

@triton.jit
def fused_relu_mul_add(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    relu_x = tl.where(x > 0, x, 0)
    y = relu_x + relu_x * 2
    tl.store(y_ptr + offsets, y, mask=mask)
```
This auto-generated code runs at close to hand-written speed.
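You can inspect what Inductor actually emits for an expression like the one above; a minimal sketch using the logging API (equivalent to setting the TORCH_LOGS="output_code" environment variable):

```python
import torch

torch._logging.set_logs(output_code=True)  # log Inductor's generated kernel source

@torch.compile
def f(x):
    return x.relu() + x.relu() * 2

f(torch.randn(1024, device="cuda"))  # the fused Triton kernel appears in the log
```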
4. Triton Backend — GPU Code Generation#
Triton (Tillet, 2019; developed at OpenAI) lets you write GPU kernels in Python syntax. It is torch.compile's default GPU backend.
Why Triton?#
- Python syntax: easier than CUDA C++
- Auto-tuning: block size and num_warps are tuned automatically
- JIT compilation: compiled at runtime
- Performance: close to hand-tuned CUDA (80-95%)
How Inductor uses Triton#
Inductor recognizes patterns and fills them into template Triton kernels:
- Element-wise ops (mul, add, relu)
- Reductions (sum, mean, max)
- Pointwise + reduction fused
Custom Triton (Module 37)#
If torch.compile's automatically generated Triton is not enough, you write Triton by hand. Module 37 covers this in detail.
CPU backend#
For CPUs, Inductor generates C++ code and compiles it with OpenMP. The gains are smaller than on GPU, but ~20-50% speedups are still possible.
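The same one-liner covers CPU; a small sketch (on a machine without a GPU, Inductor falls back to the C++/OpenMP path automatically):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.GELU())
compiled = torch.compile(model)      # no GPU needed; Inductor emits C++ here
y = compiled(torch.randn(64, 256))   # first call triggers compilation
```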
5. torch.fx — Symbolic Graph Manipulation#
torch.fx predates torch.compile (PyTorch 1.8+). It is a standalone tool: represent a PyTorch model as a symbolic graph and manipulate it.
FX GraphModule#
```python
import torch
import torch.fx as fx

class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x * 2 + 1)

model = MyModel()
traced = fx.symbolic_trace(model)
print(traced.graph)
# graph():
#     %x : [#users=1] = placeholder[target=x]
#     %mul : [#users=1] = call_function[target=mul](args=(x, 2))
#     %add : [#users=1] = call_function[target=add](args=(mul, 1))
#     %relu : [#users=1] = call_function[target=torch.relu](args=(add,))
#     return relu
```
Manipulation#
You can inspect and modify the graph:
```python
for node in traced.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        node.target = torch.nn.functional.leaky_relu  # turn ReLU into LeakyReLU

traced.graph.lint()
traced.recompile()
```
Use cases#
- Quantization: convert weights to INT8
- Profiling: measure how long each node takes (see the sketch after this list)
- Architecture search: NAS via graph mutation
- Custom transformations: domain-specific optimization
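As an example of the profiling use case, a minimal sketch built on fx.Interpreter (the class and its run_node hook are part of the public torch.fx API; the timing logic here is ours):

```python
import time
import torch
import torch.fx as fx

class ProfilingInterpreter(fx.Interpreter):
    # Re-execute the traced graph node by node, timing each one
    def run_node(self, n):
        start = time.perf_counter()
        result = super().run_node(n)
        elapsed_us = (time.perf_counter() - start) * 1e6
        print(f"{n.op:<16} {str(n.target):<40} {elapsed_us:8.1f} us")
        return result

traced = fx.symbolic_trace(MyModel())  # MyModel from the tracing example above
ProfilingInterpreter(traced).run(torch.randn(4, 8))
```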
Relationship to torch.compile#
Dynamo internally produces an FX graph, and Inductor lowers it to its lower-level IR. You benefit from compile without ever using torch.fx directly; but if you want custom optimizations, you can pre-process with torch.fx.
6. Compile Modes — default, reduce-overhead, max-autotune#
```python
torch.compile(model)                          # default mode
torch.compile(model, mode="default")          # equivalent
torch.compile(model, mode="reduce-overhead")  # latency-critical
torch.compile(model, mode="max-autotune")     # max throughput
```
Default mode#
- Fast compile time (~10 sec)
- Moderate optimization
- Compatible with most workloads
- A good baseline
reduce-overhead#
- Uses CUDA graphs
- Minimizes Python overhead (cuts kernel launch overhead)
- Ideal for small batches / latency-critical inference
- Longer compile time (~30 sec)
- Memory use grows by 20-50%
max-autotune#
- The most aggressive optimization
- Autotunes Triton kernels (block size, num_warps)
- Very long compile time (~5-15 minutes!)
- Highest throughput
- Ideal for production batched inference
Decision matrix#
| Scenario | Mode |
|---|---|
| Development, fast iteration | default |
| Production training | default or reduce-overhead |
| Production inference (single request) | reduce-overhead |
| Production inference (batched) | max-autotune |
| Latency-critical SLA | reduce-overhead |
| Maximizing throughput | max-autotune |
| Compile time not tolerable | default |
```python
import torch
import time

torch.set_float32_matmul_precision("high")

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(1024, 4096), torch.nn.ReLU(),
            torch.nn.Linear(4096, 4096), torch.nn.ReLU(),
            torch.nn.Linear(4096, 1024),
        )

    def forward(self, x):
        return self.layers(x)

device = "cuda"
model = MLP().to(device)
x = torch.randn(32, 1024, device=device)

def bench(fn, n_warmup=10, n=100):
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n * 1000

# Eager baseline
t_eager = bench(lambda: model(x))
print(f"Eager: {t_eager:.3f} ms")

# Default
model_d = torch.compile(model, mode="default")
t_d = bench(lambda: model_d(x))
print(f"Default: {t_d:.3f} ms ({t_eager/t_d:.2f}x)")

# Reduce-overhead
model_r = torch.compile(model, mode="reduce-overhead")
t_r = bench(lambda: model_r(x))
print(f"Reduce-overhead: {t_r:.3f} ms ({t_eager/t_r:.2f}x)")

# Max-autotune (3-10 minutes to compile!)
model_m = torch.compile(model, mode="max-autotune")
t_m = bench(lambda: model_m(x))
print(f"Max-autotune: {t_m:.3f} ms ({t_eager/t_m:.2f}x)")

# H100 typical:
# Eager:           0.42 ms
# Default:         0.31 ms (1.35x)
# Reduce-overhead: 0.18 ms (2.33x)
# Max-autotune:    0.14 ms (3.0x)
```
Compile modes comparison benchmark.
7. Graph Breaks — Detection and Fixes#
Dynamo cannot put some code into the graph → graph break.
Common causes#
- print(), logging.info(): side effects
- Python data structure manipulation: dict.pop(), list.append()
- Using custom Python classes (some cases)
- numpy ops
- torch.no_grad() context manager (formerly; improved in 2.1+)
- Item access (tensor.item())
- Boolean tensor in an if (if x > 0)
- Custom autograd Functions (mostly OK in 2.4+)
Impact#
A graph break splits the function in two:
- The part before: compiled
- The break itself: eager (slow)
- The part after: compiled
Many graph breaks means nearly eager execution, which makes compiling pointless.
Detection#
```python
import torch._dynamo as dynamo

dynamo.config.verbose = True
dynamo.config.suppress_errors = False

# Detailed graph break log
import os
os.environ["TORCH_LOGS"] = "graph_breaks"
```
or programmatically:
```python
explanation = dynamo.explain(model, x)
print(explanation)
# Shows which nodes break and how many graph fragments result
```
Çözüm pattern'leri#
- print() yerine logger callback (compile-after)
- Tensor.item() kaldır — symbolic shape kullan
- if branches: torch.where() ile dispatch
- Custom op'lar: veya
torch.compile.disableile escapetorch._dynamo.disable - Custom autograd: ile register
@torch.library.custom_op
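For the torch.where pattern, a small before/after sketch:

```python
import torch

@torch.compile
def gated_breaks(x):
    if x.sum() > 0:  # Python `if` on a tensor value → graph break
        return x * 2
    return -x

@torch.compile
def gated_fused(x):
    # torch.where keeps the branch as tensor ops, so the graph stays whole
    return torch.where(x.sum() > 0, x * 2, -x)
```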
Practical workflow#
```python
# 1. Compile
model_c = torch.compile(model)

# 2. Log the graph breaks
os.environ["TORCH_LOGS"] = "graph_breaks"
model_c(x)

# 3. Identify them, apply a fix
# 4. Recompile + test
```
8. Dynamic Shapes — 2.1+ Improvements#
In PyTorch 2.0, torch.compile recompiled for every unique input shape. That is fatal for LLM inference (every sequence length is different).
2.1+ dynamic shapes#
```python
torch.compile(model, dynamic=True)   # all shapes dynamic
torch.compile(model, dynamic=False)  # static (recompile per shape)
torch.compile(model, dynamic=None)   # auto (default)
```
Symbolic shapes#
Dynamo tracks shapes as integer symbols, e.g. s0 = batch_size, s1 = seq_len.
```python
@torch.compile
def model(x):
    return x.sum(dim=-1)

# x shape (4, 128)  → first compile
# x shape (8, 256)  → re-compile (with dynamic=None)
# x shape (16, 512) → cache hit (shapes now treated as dynamic)
```
The LLM context#
LLM inference has variable prompt lengths and variable generation lengths. Dynamic shapes are critical:
- Engines like vLLM and TGI handle this in sophisticated ways
- Simple cases can be handled with direct torch.compile + dynamic=True
Mark dynamic#
Mark specific dimensions as "definitely dynamic":
```python
torch._dynamo.mark_dynamic(x, 0)  # batch dim dynamic
torch._dynamo.mark_dynamic(x, 1)  # seq dim dynamic
```
This reduces the risk of recompilation for the model.
9. Production Deployment Patterns#
Pattern 1: Compile at server boot#
```python
# server.py
import torch

model = load_model()
model = torch.compile(model, mode="max-autotune")

# Warm-up: trigger compilation with representative shapes
for shape in [(1, 128), (1, 512), (1, 2048), (4, 128), (4, 512)]:
    dummy = torch.randn(*shape, device="cuda")
    _ = model(dummy)

# Server ready; production requests can start
```
Boot takes 5-30 minutes (depending on the mode), but every subsequent request runs at max speed.
Pattern 2: AOTInductor (ahead-of-time)#
PyTorch 2.4+ ships AOTInductor: compile, serialize, and load on another machine.
```python
# Compile + save (development)
ep = torch.export.export(model, (x,))
torch._inductor.aot_compile(ep.module(), (x,),
                            options={"aot_inductor.output_path": "model.so"})

# Load in production (AOTInductor has its own loader, not torch.jit.load)
model_aot = torch._export.aot_load("model.so", "cuda")
```
Ideal for production deployment: compile once, deploy to every machine.
Pattern 3: Selective compilation#
Instead of compiling the whole model, compile just the critical path:
```python
class HybridModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.compile(Encoder())  # hot path
        self.decoder = Decoder()                 # eager
        self.special_op = SpecialOp()            # custom, eager

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return self.special_op(decoded)
```
10. torch.compile in LLM Inference#
Vanilla torch.compile#
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model = torch.compile(model, mode="reduce-overhead")
```
Generation speedup: ~30-50%.
Problems#
- KV cache management: graph break risk
- Dynamic shapes differ between prompt and generation phases
- Sampling is Python code: graph break
Modern solutions#
Hugging Face transformers has been torch.compile-aware since 2024:
- The KV cache is a structured tensor (graph-friendly)
- Sampling runs as Triton kernels (compile-friendly)
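A minimal sketch of this compile-friendly path; it assumes a recent transformers release where cache_implementation="static" is available in generate():

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16).to("cuda")

# A static KV cache keeps decode-step shapes fixed, so the forward compiles cleanly
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tok("The capital of France is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print(tok.decode(out[0], skip_special_tokens=True))
```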
Comparison with vLLM#
vLLM uses its own optimization stack (PagedAttention + Triton kernels + CUDA graphs). For LLM inference it is much faster than plain torch.compile.
torch.compile + vLLM (vLLM 0.5+): vLLM is already optimized, so torch.compile adds only a small extra gain. Module 35 has the details.
Recommendation#
- Training: use torch.compile (~20-30% speedup)
- Inference: specialized engines like vLLM, TGI, or SGLang, with torch.compile as a fallback
11. Compile Time vs Runtime Trade-off#
| Mode | Compile time | Runtime speedup | Best for |
|---|---|---|---|
| Eager | 0 | 1.0x | Development, small models |
| Default | 5-30 sec | 1.3-1.5x | Training, general |
| Reduce-overhead | 30 sec - 2 min | 1.5-2.5x | Inference, latency |
| Max-autotune | 5-15 min | 2-3.5x | Production batched |
| AOTInductor | Pre-compile | 2-3x (loaded) | Distributed deployment |
Break-even analysis#
```python
# Eager: 5 ms per inference
# Compiled (default): 8 sec compile, 3.5 ms per inference
# Saving per inference: 1.5 ms
# Break-even: 8 sec / 1.5 ms ≈ 5333 inferences
```
Below ~5K requests, compiling costs more than it saves. Above 5K it pays off.
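The same arithmetic as a small helper; the second call's numbers are hypothetical, just to show the max-autotune case:

```python
def break_even_requests(compile_s: float, eager_ms: float, compiled_ms: float) -> float:
    # Requests needed before compile time is amortized by the per-request saving
    return compile_s * 1000 / (eager_ms - compiled_ms)

print(break_even_requests(8.0, 5.0, 3.5))    # ≈ 5333 (the default-mode example above)
print(break_even_requests(600.0, 5.0, 2.0))  # ≈ 200000 (hypothetical 10-min max-autotune)
```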
Production scenario#
A typical serving workload of 1M+ requests per day makes max-autotune pay off easily (3 minutes of compile versus days of uptime).
Development scenario#
A 3-minute compile on every change kills the workflow. Iterate with default mode or eager.
12. Interaction with vLLM and TGI#
LLM inference engines (detailed in Module 35) each relate to torch.compile differently:
vLLM#
- Custom CUDA kernels (PagedAttention, fused operations)
- Custom-written Triton kernels
- torch.compile optional (vLLM 0.5+)
- Advantage: vLLM-specific optimizations (batching, KV cache)
TGI (Text Generation Inference)#
- Hugging Face's inference engine
- Uses torch.compile partially
- Custom Rust runtime + CUDA kernels
SGLang#
- Custom kernels for RadixAttention
- torch.compile integration
TensorRT-LLM#
- NVIDIA's specialized inference framework
- NOT torch.compile; it has its own compilation stack (TensorRT)
- The most aggressive optimization, but NVIDIA lock-in
Decision matrix#
| Use case | Recommended |
|---|---|
| LLM training | torch.compile + DeepSpeed/FSDP |
| LLM inference (production) | vLLM (NVIDIA), llama.cpp (CPU/Apple), MLX |
| LLM inference (NVIDIA-only, max perf) | TensorRT-LLM |
| LLM inference (research, flexibility) | HF transformers + torch.compile |
| Custom model (non-LLM) | torch.compile + AOTInductor |
13. Mini Exercises#
- Compile modes comparison: For a CNN model, what do you estimate each mode's compile time and runtime to be?
- Graph break debugging: Your model contains print(x.shape). Does this cause a graph break? What is the fix?
- Dynamic shapes: LLM inference with prompt lengths from 64 to 4096. dynamic=True or dynamic=None?
- Production deploy: For a server handling 1M req/day, what are the optimal mode and warm-up strategy?
- vLLM vs torch.compile: Which would you pick for 8B Llama inference in 2026?
What Did We Learn in This Lesson?#
✓ The 3 layers of torch.compile: TorchDynamo + TorchInductor + Triton
✓ Dynamo's Python bytecode capture via the frame eval API
✓ Inductor optimizations: lowering, fusion, scheduling
✓ The Triton backend and auto-generated GPU kernels
✓ torch.fx: symbolic graph manipulation (independent + complementary)
✓ Compile modes: default, reduce-overhead, max-autotune
✓ Graph breaks: detection (TORCH_LOGS) and 5+ fix patterns
✓ Dynamic shapes: 2.1+ improvements, critical for LLMs
✓ Production patterns: boot warm-up, AOTInductor, selective compile
✓ torch.compile in LLM inference: comparison with vLLM/TGI
✓ The compile-time trade-off and break-even analysis
Next Lesson#
5.2 — Mixed Precision Training: BF16, FP16, FP8 + autocast & GradScaler
Module 1.9 covered the fundamentals of numerical stability. Next up is production mixed precision: autocast patterns, GradScaler edge cases, FP8 (Hopper/Blackwell) early adoption, and gradient-norm monitoring.
Frequently Asked Questions
Why did torch.compile only arrive with PyTorch 2.0 (March 2023)? Several technical advances converged: (1) the **Python frame eval API** (PEP 523, 3.6+) enabled Dynamo's bytecode capture. (2) **Triton** matured after its 2019 debut, making GPU kernel codegen practical. (3) **Experience with TorchScript and LazyTensor** showed what works and what doesn't. (4) **Meta's strategic decision** to dedicate the PyTorch 2.0 brand to it. Previous attempts (TorchScript, torch.jit.trace) had fallen short, so PyTorch 2.0 was a fresh start.