İçeriğe geç

torch.compile ve torch.fx: Graph Capture, JIT Compilation ve Production Optimization

PyTorch 2.0+'ın game-changer feature'ı torch.compile derinlemesine: TorchDynamo + TorchInductor + Triton akışı, FX graph manipulation, compile modes (default/reduce-overhead/max-autotune), graph breaks debugging, dynamic shapes, production trade-off'lar. Modül 2.5'in production extension'ı.

Şükrü Yusuf KAYA
60 dakikalık okuma
İleri
torch.compile ve torch.fx: Graph Capture, JIT Compilation ve Production Optimization
⚡ PyTorch'un 'mucize katmanı'
Modül 2.5'te PyTorch eager vs JAX static graph karşılaştırmasını yaptık. torch.compile (PyTorch 2.0, Mart 2023) ikisinin de avantajını birleştirdi: eager Python'u arkadan compile ederek static graph performansı. 60 dakika sonra: TorchDynamo Python bytecode'u nasıl yakalıyor, TorchInductor neden Triton'a derliyor, graph break'leri nasıl debug edersin — production'da torch.compile sahibi olacaksın.

Ders Haritası (Detaylı)#

  1. torch.compile'ın 3 katmanı: Dynamo, Inductor, Triton
  2. TorchDynamo: Python bytecode capture
  3. TorchInductor: graph → optimized kernels
  4. Triton backend: GPU code generation
  5. torch.fx: symbolic graph manipulation
  6. Compile modes: default vs reduce-overhead vs max-autotune
  7. Graph breaks: tespit, debug, çözüm
  8. Dynamic shapes: 2.1+ improvements
  9. Production deployment patterns
  10. LLM inference'ta torch.compile
  11. Compile time vs runtime trade-off
  12. vLLM, TGI ile etkileşim

1. torch.compile'ın 3 Katmanı#

torch.compile tek bir şey değil — bir stack:
Python code ↓ [1] TorchDynamo — Python bytecode → FX graph ↓ [2] TorchInductor — FX graph → optimized lower-level IR ↓ [3] Triton (default) — IR → GPU kernels ↓ GPU

Niye katmanlı?#

Her katman bir abstraction level:
  • Dynamo: "ne yapıyor?" (graph capture)
  • Inductor: "nasıl optimize ederim?" (fusion, scheduling)
  • Triton: "GPU'da nasıl çalıştırırım?" (kernel codegen)

Alternative backends#

Inductor default backend'tir; başkaları da var:
  • nvFuser (NVIDIA): NVIDIA-specific optimization
  • OpenXLA: Google XLA-based
  • AOTInductor: ahead-of-time compilation (production ready)
  • TensorRT: NVIDIA TensorRT integration (2024+)

Mode flag#

@torch.compile(backend="inductor") # default @torch.compile(backend="aot_eager") # debugging @torch.compile(backend="cudagraphs") # CUDA graphs only

2. TorchDynamo — Python Bytecode Capture#

Problem: Python kodu nasıl static graph'a çevrilir?
Önceki denemeler (TorchScript, TorchTrace) kısmen başarılıydı:
  • TorchScript: Python subset (önce yazılması zor)
  • TorchTrace: bir input ile trace (control flow kayboluyor)

Dynamo yaklaşımı#

CPython interpreter'ın frame evaluation hook'unu kullan. Python bytecode'u runtime'da intercept et, graph'a çevir.

Akış#

1. Python function call (örn. model.forward()) 2. Dynamo bytecode'u inspect eder 3. Tensor operations'i FX graph'a aktarır 4. Non-tensor Python (control flow, prints) "graph break" oluşturur 5. Graph'lar compile edilir, Python kalanı eager kalır

Graph breaks#

Dynamo bazı şeyleri graph'a koyamaz:
  • print()
    ,
    logging
  • Python data structure modifications
  • Custom Python class methods (bazı durumlar)
  • Numpy operations
  • Custom autograd Function (kısmen)
Bu durumda graph break olur — function ikiye bölünür: önce compiled, sonra eager.

Frame evaluation API#

Python 3.11+'de Dynamo daha iyi (frame eval API hooks geliştirildi). 3.10'da çalışır ama performance daha düşük.

3. TorchInductor — FX Graph → Optimized IR#

Dynamo FX graph üretir. Şimdi optimize et.

Inductor'un işlevleri#

  1. Lowering: high-level ops (aten.matmul) → low-level primitives
  2. Fusion: ardışık op'ları birleştir (mul + add → fused multiply-add)
  3. Memory planning: tensor lifetimes, optimal allocation
  4. Scheduling: GPU stream'lerde paralel execution
  5. Code generation: Triton (GPU) veya C++ (CPU) için kernel kodu

Operator fusion örnek#

y = x.relu() + x.relu() * 2
Naive: 3 separate kernel (relu, mul, add) → 3 memory round-trips. Inductor fusion: 1 kernel, in-register computation → 3x speedup.

Inductor IR#

Internal representation, FX'ten daha low-level. Loop nests, memory accesses açık.

Code generation#

GPU için Triton kernel üretir:
@triton.jit def fused_relu_mul_add(x_ptr, y_ptr, n, BLOCK: tl.constexpr): pid = tl.program_id(0) offsets = pid * BLOCK + tl.arange(0, BLOCK) mask = offsets < n x = tl.load(x_ptr + offsets, mask=mask) relu_x = tl.where(x > 0, x, 0) y = relu_x + relu_x * 2 tl.store(y_ptr + offsets, y, mask=mask)
Bu otomatik üretilen kod, hand-written'a yakın hız.

4. Triton Backend — GPU Code Generation#

Triton (Tillet 2019, OpenAI) — Python syntax ile CUDA kernel yazma. PyTorch torch.compile default GPU backend.

Niye Triton?#

  • Python syntax: CUDA C++'tan kolay
  • Auto-tuning: block size, num warps otomatik
  • JIT compilation: runtime'da derle
  • Performans: hand-tuned CUDA'ya yakın (%80-95)

Inductor'un Triton kullanımı#

Inductor pattern'leri tanıyıp template Triton kernel'a doldurur:
  • Element-wise ops (mul, add, relu)
  • Reductions (sum, mean, max)
  • Pointwise + reduction fused

Custom Triton (Modül 37)#

torch.compile'ın otomatik Triton'u yeterli olmazsa, manuel Triton yazarsın. Modül 37'de detayda.

CPU backend#

CPU için Inductor C++ kodu üretir, sonra OpenMP ile derler. GPU'dan daha az kazanç ama yine ~%20-50 speedup mümkün.

5. torch.fx — Symbolic Graph Manipulation#

torch.fx torch.compile'dan eskidir (PyTorch 1.8+). Bağımsız bir araç: PyTorch model'i symbolic graph olarak temsil et + manipulate et.

Fx GraphModule#

import torch import torch.fx as fx class MyModel(torch.nn.Module): def forward(self, x): return torch.relu(x * 2 + 1) model = MyModel() traced = fx.symbolic_trace(model) print(traced.graph) # graph(): # %x : [#users=1] = placeholder[target=x] # %mul : [#users=1] = call_function[target=mul](args=(x, 2)) # %add : [#users=1] = call_function[target=add](args=(mul, 1)) # %relu : [#users=1] = call_function[target=torch.relu](args=(add,)) # return relu

Manipulation#

Graph'ı inspect + modify edebilirsin:
for node in traced.graph.nodes: if node.op == "call_function" and node.target is torch.relu: node.target = torch.nn.functional.leaky_relu # ReLU'yu LeakyReLU yap traced.graph.lint() traced.recompile()

Use cases#

  1. Quantization: weight'leri INT8'e dönüştür
  2. Profiling: hangi node ne kadar süre alıyor
  3. Architecture search: graph mutation ile NAS
  4. Custom transformations: domain-specific optimization

torch.compile ile ilişki#

Dynamo internally fx graph üretir. Inductor onu lower-level IR'a çevirir. Sen torch.fx'i direkt kullanmadan compile'dan faydalanırsın — ama özel optimization istiyorsan torch.fx ile pre-processing yapabilirsin.

6. Compile Modes — default, reduce-overhead, max-autotune#

torch.compile(model) # default mode torch.compile(model, mode="default") # eşdeğer torch.compile(model, mode="reduce-overhead") # latency-critical torch.compile(model, mode="max-autotune") # max throughput

Default mode#

  • Hızlı compile time (~10 sec)
  • Moderate optimization
  • Compatible most workloads
  • Good baseline

reduce-overhead#

  • CUDA graphs kullanır
  • Python overhead minimize (kernel launch overhead düşür)
  • Küçük batch / latency-critical inference için ideal
  • Compile time daha uzun (~30 sec)
  • Memory %20-50 artar

max-autotune#

  • En agresif optimization
  • Triton kernel'lar için autotuning (block size, num warps)
  • Compile time çok uzun (~5-15 dakika!)
  • En yüksek throughput
  • Production batched inference için ideal

Decision matrix#

SenaryoMode
Geliştirme, hızlı iterationdefault
Production trainingdefault veya reduce-overhead
Production inference (single request)reduce-overhead
Production inference (batched)max-autotune
Latency SLA criticalreduce-overhead
Throughput max'lememax-autotune
Compile time tolere edilemezdefault
python
import torch
import time
 
torch.set_float32_matmul_precision("high")
 
class MLP(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(1024, 4096), torch.nn.ReLU(),
torch.nn.Linear(4096, 4096), torch.nn.ReLU(),
torch.nn.Linear(4096, 1024),
)
def forward(self, x):
return self.layers(x)
 
device = "cuda"
model = MLP().to(device)
x = torch.randn(32, 1024, device=device)
 
# Eager baseline
def bench(fn, n_warmup=10, n=100):
for _ in range(n_warmup): fn()
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n): fn()
torch.cuda.synchronize()
return (time.perf_counter() - start) / n * 1000
 
t_eager = bench(lambda: model(x))
print(f"Eager: {t_eager:.3f} ms")
 
# Default
model_d = torch.compile(model, mode="default")
t_d = bench(lambda: model_d(x))
print(f"Default: {t_d:.3f} ms ({t_eager/t_d:.2f}x)")
 
# Reduce-overhead
model_r = torch.compile(model, mode="reduce-overhead")
t_r = bench(lambda: model_r(x))
print(f"Reduce-overhead: {t_r:.3f} ms ({t_eager/t_r:.2f}x)")
 
# Max-autotune (3-10 dakika compile!)
model_m = torch.compile(model, mode="max-autotune")
t_m = bench(lambda: model_m(x))
print(f"Max-autotune: {t_m:.3f} ms ({t_eager/t_m:.2f}x)")
 
# H100 typical:
# Eager: 0.42 ms
# Default: 0.31 ms (1.35x)
# Reduce-overhead: 0.18 ms (2.33x)
# Max-autotune: 0.14 ms (3.0x)
Compile modes karşılaştırma benchmark.

7. Graph Breaks — Tespit ve Çözüm#

Dynamo bazı kodları graph'a koyamaz → graph break.

Yaygın sebepler#

  1. print()
    ,
    logging.info()
    : side effect
  2. Python data structure manipulation: dict.pop(), list.append()
  3. Custom Python class kullanımı (some)
  4. numpy
    ops
  5. torch.no_grad()
    context manager (eskiden, 2.1+ improved)
  6. Item access (
    tensor.item()
    )
  7. Boolean tensor in if (
    if x > 0
    )
  8. Custom autograd Function (mostly OK in 2.4+)

Etki#

Graph break = function ikiye böler:
  • Önceki kısım: compiled
  • Break: eager (slow)
  • Sonra: compiled
Çok graph break = neredeyse eager = compile yararsız.

Detection#

import torch._dynamo as dynamo dynamo.config.verbose = True dynamo.config.suppress_errors = False # Detailed graph break log import os os.environ["TORCH_LOGS"] = "graph_breaks"
veya programmatik:
explanation = dynamo.explain(model, x) print(explanation) # Hangi node'da break, ne kadar fragmenta olduğu

Çözüm pattern'leri#

  1. print() yerine logger callback (compile-after)
  2. Tensor.item() kaldır — symbolic shape kullan
  3. if branches: torch.where() ile dispatch
  4. Custom op'lar:
    torch.compile.disable
    veya
    torch._dynamo.disable
    ile escape
  5. Custom autograd:
    @torch.library.custom_op
    ile register

Practical workflow#

# 1. Compile et model_c = torch.compile(model) # 2. Graph break'leri log'la os.environ["TORCH_LOGS"] = "graph_breaks" model_c(x) # 3. Identify et, çözüm uygula # 4. Recompile + test

8. Dynamic Shapes — 2.1+ Improvements#

PyTorch 2.0'da torch.compile her unique input shape için re-compile ediyordu. Bu LLM inference için fatal (her sequence length farklı).

2.1+ dynamic shapes#

torch.compile(model, dynamic=True) # tüm shape'ler dynamic torch.compile(model, dynamic=False) # static (recompile per shape) torch.compile(model, dynamic=None) # auto (default)
dynamic=None
(default): Dynamo ilk birkaç çağrıda shape'leri "specialize" eder, sonra dynamic'e geçer.

Symbolic shapes#

Dynamo shape'leri integer symbols olarak tutar. Örn.
s0 = batch_size, s1 = seq_len
.
@torch.compile def model(x): return x.sum(dim=-1) # x shape (4, 128) → ilk compile # x shape (8, 256) → re-compile (dynamic=None ile) # x shape (16, 512) → cache hit (dynamic'e geçildi)

LLM bağlamı#

LLM inference variable prompt + variable generation length. Dynamic shapes kritik:
  • vLLM, TGI gibi engine'ler bunu sophisticated handle ediyor
  • Direct torch.compile + dynamic=True ile basit cases handle edilebilir

Mark dynamic#

Belirli shape'leri "kesinlikle dynamic" işaretle:
torch._dynamo.mark_dynamic(x, 0) # batch dim dynamic torch._dynamo.mark_dynamic(x, 1) # seq dim dynamic
Bu, model üzerinde re-compile riski azaltır.

9. Production Deployment Patterns#

Pattern 1: Compile at server boot#

# server.py import torch model = load_model() model = torch.compile(model, mode="max-autotune") # Warm-up: representative shape'lerle compile for shape in [(1, 128), (1, 512), (1, 2048), (4, 128), (4, 512)]: dummy = torch.randn(*shape, device="cuda") _ = model(dummy) # Server hazır — production requests başlasın
Boot time 5-30 dakika (mode'a göre) ama sonraki requests max-speed.

Pattern 2: AOTInductor (ahead-of-time)#

PyTorch 2.4+'de AOTInductor: compile et, serialize et, başka makinede de yükle.
# Compile + save (development) torch._inductor.aot_compile(model, x, options={"aot_inductor.output_path": "model.so"}) # Production'da yükle model_pt = torch.jit.load("model.so")
Production deployment için ideal: compile bir kez, deploy her makine.

Pattern 3: Selective compilation#

Tam modeli compile etmek yerine kritik path compile et:
class HybridModel(torch.nn.Module): def __init__(self): super().__init__() self.encoder = torch.compile(Encoder()) # hot path self.decoder = Decoder() # eager self.special_op = SpecialOp() # custom, eager def forward(self, x): encoded = self.encoder(x) decoded = self.decoder(encoded) return self.special_op(decoded)

10. LLM Inference'ta torch.compile#

Vanilla torch.compile#

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B") model = torch.compile(model, mode="reduce-overhead")
Generation speedup: ~%30-50.

Sorunlar#

  1. KV cache management — graph break risk
  2. Dynamic shape prompt vs generation
  3. Sampling Python code — graph break

Modern çözümler#

Hugging Face transformers 2024'ten itibaren torch.compile-aware:
  • KV cache structured Tensor olarak (graph-friendly)
  • Sampling Triton kernel ile (compile-friendly)

vLLM ile karşılaştırma#

vLLM kendi optimization stack'ini kullanıyor (PagedAttention + Triton kernels + CUDA graphs). torch.compile'a göre çok daha hızlı LLM inference için.
torch.compile + vLLM (vLLM 0.5+): vLLM zaten optimize, torch.compile küçük ek kazanç verir. Modül 35 detayda.

Önerilen#

  • Training: torch.compile kullan (~%20-30 speedup)
  • Inference: vLLM, TGI, SGLang gibi specialized engine'ler. torch.compile fallback

11. Compile Time vs Runtime Trade-off#

ModeCompile timeRuntime speedupBest for
Eager01.0xDevelopment, small models
Default5-30 sec1.3-1.5xTraining, general
Reduce-overhead30 sec - 2 min1.5-2.5xInference, latency
Max-autotune5-15 min2-3.5xProduction batched
AOTInductorPre-compile2-3x (loaded)Distributed deployment

Break-even analizi#

# Eager: 5 ms per inference # Compile (default): 8 sec compile, 3.5 ms per inference # Saving per inference: 1.5 ms # Break-even: 8 sec / 1.5 ms = 5333 inferences
5K request altında compile maliyetli. 5K üstü kazançlı.

Production scenario#

Average serving: 1M+ request per day → max-autotune kolayca pays off (3 min compile vs günlerce uptime).

Development scenario#

Her change'de 3 min compile = workflow killer. Default mode veya eager ile iterate.

12. vLLM, TGI ile Etkileşim#

LLM inference engines (Modül 35'te detay) torch.compile ile farklı ilişkiye sahip:

vLLM#

  • Custom CUDA kernels (PagedAttention, fused operations)
  • Triton kernels custom-written
  • torch.compile opsiyonel (vLLM 0.5+)
  • Avantaj: vLLM-specific optimization (batching, KV cache)

TGI (Text Generation Inference)#

  • Hugging Face'in inference engine'i
  • torch.compile kısmen kullanılıyor
  • Custom Rust runtime + CUDA kernels

SGLang#

  • RadixAttention için custom
  • torch.compile entegrasyonu

TensorRT-LLM#

  • NVIDIA'nın specialized inference framework'ü
  • torch.compile DEĞİL — kendi compilation stack'i (TensorRT)
  • En agresif optimization but NVIDIA-lock-in

Karar matrisi#

Use caseRecommended
LLM trainingtorch.compile + DeepSpeed/FSDP
LLM inference (production)vLLM (NVIDIA), llama.cpp (CPU/Apple), MLX
LLM inference (NVIDIA-only, max perf)TensorRT-LLM
LLM inference (research, flexibility)HF transformers + torch.compile
Custom model (non-LLM)torch.compile + AOTInductor

13. Mini Egzersizler#

  1. Compile modes karşılaştırma: Bir CNN modeli için her mode'un compile time + runtime tahmini ne?
  2. Graph break debug: Modelinde
    print(x.shape)
    var. Bu graph break yapar mı? Çözüm?
  3. Dynamic shapes: LLM inference, prompt length 64-4096 arası. dynamic=True mi None mı?
  4. Production deploy: 1M req/day server için optimal mode + warm-up stratejisi?
  5. vLLM vs torch.compile: 8B Llama inference için 2026'da hangisi tercih?

Bu Derste Neler Öğrendik?#

torch.compile 3 katmanı: TorchDynamo + TorchInductor + Triton ✓ Dynamo Python bytecode capture — frame eval API ✓ Inductor optimization: lowering, fusion, scheduling ✓ Triton backend — auto-generated GPU kernels ✓ torch.fx: symbolic graph manipulation (independent + complementary) ✓ Compile modes: default, reduce-overhead, max-autotune ✓ Graph breaks: detect (TORCH_LOGS), 5+ çözüm pattern ✓ Dynamic shapes — 2.1+ improvements, LLM critical ✓ Production patterns: boot warm-up, AOTInductor, selective compile ✓ LLM inference'ta: vLLM/TGI ile karşılaştırma ✓ Compile time trade-off — break-even analizi

Sıradaki Ders#

5.2 — Mixed Precision Training: BF16, FP16, FP8 + autocast & GradScaler Modül 1.9'da numerik stabilite temellerini gördük. Şimdi production mixed precision: autocast pattern'leri, GradScaler edge case'leri, FP8 (Hopper/Blackwell) early adoption, gradient norm monitoring.

Sık Sorulan Sorular

Birkaç teknik gelişme bir araya geldi: (1) **Python frame eval API** (PEP 523, 3.6+) Dynamo'nun bytecode capture'ı mümkün kıldı. (2) **Triton** 2019'da olgunlaştı, GPU kernel codegen pratik oldu. (3) **TorchScript ve LazyTensor deneyimi** — what works, what doesn't öğrenildi. (4) **Meta'nın stratejik kararı**: PyTorch 2.0 brand'ini buna ayırmak. Önceki yaklaşımlar (TorchScript, TorchTrace) başarısız olduğu için PyTorch 2.0 fresh start.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

İlgili İçerikler