Skip to content

Custom GPU Kernels with Triton: Softmax, Matmul, FlashAttention Mini from Scratch

Triton's secret of GPU programming with Python syntax: programming model (program_id, block_size, autotune), softmax kernel from scratch, matmul tiling, FlashAttention's block-wise mini implementation, performance tuning. Practical foundation for Module 37 (CUDA/Triton deep dive).

Şükrü Yusuf KAYA
60 min read
Advanced
Triton ile Custom GPU Kernels: Softmax, Matmul, FlashAttention Mini Sıfırdan
🔥 GPU programming'in Python devrimi
CUDA C++ yazmak iki ay öğrenmek demek. Triton ile Python syntax kullanarak hand-tuned CUDA'ya yakın performansta kernel yazabilirsin — 1 hafta. FlashAttention, Mamba, ColPali, vLLM custom kernel'ları — hepsi Triton. 60 dakika sonra: kendi softmax kernel'ını yazma, matmul tiling matematiği, FlashAttention block-wise mini implementasyon — hepsini bileceksin.

Ders Haritası#

  1. Triton nedir, niye?
  2. GPU programming model basics
  3. İlk Triton kernel: vector add
  4. Memory access patterns: coalesced loads
  5. Softmax kernel sıfırdan
  6. Matmul tiling matematiği
  7. Autotune ile hyperparameter search
  8. FlashAttention block-wise mini implementation
  9. PyTorch entegrasyonu: torch.autograd.Function ile
  10. Benchmark ve profiling
  11. Production patterns
  12. Modül 37'ye köprü

1. Triton Nedir, Niye?#

Triton (Philippe Tillet, OpenAI, 2019). Python-syntax GPU kernel language + JIT compiler.

Niye CUDA C++ değil?#

CUDA C++ zorlukları:
  • Düşük seviye memory management (shared, register, global)
  • Thread block / warp / tile mathematics manuel
  • Compilation NVCC ile yavaş
  • Debugging zor
  • Verification (correctness) çok zaman alıyor

Triton avantajları#

  1. Python syntax: işin yarısı zaten Python-friendly
  2. Auto-tuning: block size, num warps Triton seçiyor
  3. JIT compilation: runtime'da Python'dan PTX'e
  4. Pythonic memory model: pointer + offset (CUDA pointer arithmetic kadar değil)
  5. Performance: hand-tuned CUDA'ya %80-95 yakın

Adoption#

  • FlashAttention (Dao 2022): Triton ile yazıldı
  • vLLM PagedAttention: Triton
  • PyTorch torch.compile: Inductor → Triton (default backend)
  • OpenAI internal: birçok production kernel

Türk perspektifi#

Triton 2026'da bilinen ama yaygın kullanılan değil. Akademik araştırma, advanced LLM mühendisliği için zorunlu. Frontier-aspirant Türk şirketleri (DeepSeek-tarzı) öğrenmeli.

2. GPU Programming Model Basics#

GPU'nun hierarchical paralellik modeli:

Hierarchy#

Grid (kernel launch) └── Blocks (paralel) └── Warps (32 thread) └── Threads (paralel SM içinde)

Memory hierarchy#

Global memory: GBs, slow (~500 GB/s) L2 cache: MBs, faster Shared memory: KBs per block, fast (~10 TB/s effective) Registers: per thread, fastest

Triton abstraction#

Triton thread'leri soyutluyor; sen block'lar üzerinden çalışıyorsun:
  • tl.program_id(axis)
    : bu block'un grid'deki ID'si
  • tl.arange(0, BLOCK)
    : block içinde paralel offset'ler
  • Tüm operasyonlar block-level, thread'leri Triton handle ediyor
Bu, CUDA C++'tan çok daha basit mental model.

Block size tuning#

Hyperparameter: kaç element/block.
  • Küçük (128): fazla overhead, az parallelism
  • Büyük (4096): registers/shared memory dolar, occupancy düşer
  • Optimal genelde 1024-2048 GPU'ya göre
Triton autotune bunu otomatik bulur.

3. İlk Triton Kernel: Vector Add#

import torch import triton import triton.language as tl @triton.jit def add_kernel( x_ptr, # pointer to x y_ptr, # pointer to y out_ptr, # pointer to output n, # toplam element sayısı BLOCK_SIZE: tl.constexpr, # compile-time constant ): # Bu block'un ID'si pid = tl.program_id(axis=0) # Block içindeki offset'ler offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # Boundary mask (son block için) mask = offsets < n # Load x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) # Compute out = x + y # Store tl.store(out_ptr + offsets, out, mask=mask) def add(x, y): out = torch.empty_like(x) n = x.numel() # Grid: kaç block lazım? grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),) add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024) return out # Test x = torch.randn(10000, device="cuda") y = torch.randn(10000, device="cuda") out = add(x, y) print(torch.allclose(out, x + y)) # True

Analiz#

  • @triton.jit
    : function compile edilecek
  • tl.constexpr
    : compile-time'da bilinen sabit (block size)
  • tl.program_id
    : hangi block çalışıyor
  • tl.load/store
    : memory operations with mask
  • Grid lambda: cdiv(n, BLOCK_SIZE) block'a bölünür

Performance#

Vector add: memory-bound (compute trivial, memory transfer bottleneck). Triton bu kadar basit problemde PyTorch'la eşit. Asıl fark karmaşık kernel'larda.

4. Memory Access Patterns — Coalesced Loads#

GPU memory çok geniş (bus width 1024-bit). En verimli pattern: coalesced (ardışık thread'ler ardışık memory).

Coalesced vs Non-coalesced#

Coalesced (good): Thread 0 → addr 0 Thread 1 → addr 1 Thread 2 → addr 2 ...
Non-coalesced (bad): Thread 0 → addr 0 Thread 1 → addr 32 Thread 2 → addr 64 ...
İkincisinde her thread ayrı 32-byte transaction → bandwidth düşer.

Triton'da pattern#

offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offsets, mask=mask)
tl.arange(0, BLOCK_SIZE)
ardışık integer'lar → ardışık offsets → coalesced.

2D access#

Matris üzerinde:
# Row-major matris, M × N row_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # 2D ptr mat_ptr = X_ptr + row_offsets[:, None] * N + col_offsets[None, :] data = tl.load(mat_ptr, mask=...)
[:, None]
ve
[None, :]
broadcasting ile 2D pointer grid yaratıyoruz.

Modern Triton: block pointer#

PyTorch 2.0+ Triton'da
tl.make_block_ptr
daha temiz:
block_ptr = tl.make_block_ptr( base=X_ptr, shape=(M, N), strides=(N, 1), offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0), ) data = tl.load(block_ptr, boundary_check=(0, 1))
Hem okunaklı hem performant.

5. Softmax Kernel Sıfırdan#

Softmax PyTorch'tan 2-3x daha hızlı yazılabilir Triton ile (fused load-compute-store).
@triton.jit def softmax_kernel( output_ptr, input_ptr, row_stride, # bytes between rows n_cols, # number of columns BLOCK_SIZE: tl.constexpr, ): # Her block bir satır işler row_idx = tl.program_id(0) row_start_ptr = input_ptr + row_idx * row_stride col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets mask = col_offsets < n_cols # Load row row = tl.load(input_ptrs, mask=mask, other=-float('inf')) # Numerical stability: subtract max row_max = tl.max(row, axis=0) row_minus_max = row - row_max # Exp + sum numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) # Normalize softmax_output = numerator / denominator # Store output_row_start_ptr = output_ptr + row_idx * row_stride output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output, mask=mask) def softmax(x): n_rows, n_cols = x.shape BLOCK_SIZE = triton.next_power_of_2(n_cols) output = torch.empty_like(x) softmax_kernel[(n_rows,)]( output, x, x.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, ) return output # Test x = torch.randn(1000, 512, device="cuda") out = softmax(x) print(torch.allclose(out, torch.softmax(x, dim=-1), atol=1e-5)) # True

Niye PyTorch'tan hızlı?#

PyTorch native softmax:
1. exp(x - max(x)) — yeni tensor, full memory write 2. sum(exp_x) — reduction, başka tensor 3. div(exp_x, sum) — division, başka tensor
3 memory pass. Triton fused softmax:
1. Load row 2. Compute (max, exp, sum, divide) — registers 3. Store
1 memory pass → 3x daha az memory traffic → 2-3x daha hızlı.

Profil#

Triton softmax 512-col row için ~5 GB/s effective. cuBLAS softmax ~10 GB/s (still better, NVIDIA optimized). Custom kernel cuBLAS'ı tam yakalayamaz ama %50-70 erişebilir.

6. Matmul Tiling Matematiği#

Matrix multiplication
C = A @ B
where A is (M, K), B is (K, N), C is (M, N).

Naive#

for i in range(M): for j in range(N): for k in range(K): C[i, j] += A[i, k] * B[k, j]
M, K, N ~4096 → 68B operations. CPU: dakikalar. GPU naive: saniyeler.

Tiling — block-level paralel#

Matrix'i tile'lara böl, her block bir tile işler:
A_tile = A[i*BLOCK_M:(i+1)*BLOCK_M, k*BLOCK_K:(k+1)*BLOCK_K] B_tile = B[k*BLOCK_K:(k+1)*BLOCK_K, j*BLOCK_N:(j+1)*BLOCK_N] C_tile += A_tile @ B_tile
Block içinde compute paralel, block'lar arası paralel.

Triton matmul kernel#

@triton.jit def matmul_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, K, BLOCK_K): # Load tiles a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k) # Accumulate accumulator += tl.dot(a, b) # Advance K a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk # Store result offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn tl.store(c_ptrs, accumulator, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))

Pratik performance#

Bu naive Triton matmul (1024×1024×1024): ~5 ms. cuBLAS (PyTorch native): ~1.5 ms.
3x slower — niye? Optimal kernel:
  • Shared memory caching
  • Register-level tiling
  • Tensor Core utilization (BF16/FP16)
  • Software pipelining
cuBLAS yıllarca tuned. Triton manuel optimization ile %70-80'ine ulaşabilir.

Realistic use#

Matmul için cuBLAS use et (PyTorch
@
operator). Custom kernel mantıklı fused operations (örn. matmul + bias + relu + dropout fused).

7. Autotune ile Hyperparameter Search#

Block size, num warps, num stages — kernel performansını dramatically etkiliyor. Triton autotune ile otomatik en iyiyi bulur.
@triton.autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4), triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32}, num_warps=8), triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8), triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4), ], key=["M", "N", "K"], # Bu key'lere göre cache'le ) @triton.jit def matmul_kernel(...): ...

Behavior#

İlk çağrıda her config çalıştırılır, en hızlısı seçilir. Bu config (M, N, K) için cache'lenir. Sonraki çağrılarda direkt en iyisi.

num_warps#

Her block'ta kaç warp (32 thread). Tipik 4, 8, 16.
  • Az: low occupancy
  • Çok: register pressure

num_stages#

Pipelining depth. Modern GPU'larda 2-4 optimal.

Auto-tuning cost#

İlk run yavaş (~1-10 saniye search). Sonra cached. Production'da:
  • Warmup'ta autotune koş
  • Cache disk'e kaydedilebilir (
    triton.runtime.cache
    )

8. FlashAttention Block-Wise Mini Implementation#

FlashAttention (Dao 2022): attention'ı block-wise hesaplıyor, O(N²) memory → O(N).

Klasik attention (slow + memory-heavy)#

S = Q @ K.T / sqrt(d) # (N, N) matrix — büyük! P = softmax(S) # (N, N) O = P @ V # (N, d)
Memory N=8192 için: 8192×8192×4 byte = 256 MB sadece S matrix.

FlashAttention insight#

Tüm S matrix'ini saklama gerekli değil. Online softmax ile block-wise işle, sonuçları incrementally birleştir.

Mini Triton implementation (illustrative)#

@triton.jit def flash_attention_kernel( Q_ptr, K_ptr, V_ptr, O_ptr, M_ptr, L_ptr, # running max, log-sum-exp N_HEAD, N_SEQ, D_HEAD, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): pid = tl.program_id(0) head_idx = tl.program_id(1) # Q block'unu yükle q_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M) Q_block = tl.load(Q_ptr + ...) # (BLOCK_M, D_HEAD) # Initialize accumulators m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32) O_block = tl.zeros([BLOCK_M, D_HEAD], dtype=tl.float32) # K, V'yi block block dolaş for kv_start in range(0, N_SEQ, BLOCK_N): # K, V blocks K_block = tl.load(K_ptr + ...) V_block = tl.load(V_ptr + ...) # Compute S = Q @ K^T S = tl.dot(Q_block, K_block.T) / tl.sqrt(D_HEAD.to(tl.float32)) # Online softmax — running max + sum m_ij = tl.maximum(m_i, tl.max(S, axis=1)) p_ij = tl.exp(S - m_ij[:, None]) l_ij = l_i * tl.exp(m_i - m_ij) + tl.sum(p_ij, axis=1) # Update output O_block = O_block * (tl.exp(m_i - m_ij) * l_i / l_ij)[:, None] + tl.dot(p_ij, V_block) / l_ij[:, None] # Update running stats m_i = m_ij l_i = l_ij # Store tl.store(O_ptr + ..., O_block)

Real FlashAttention complexity#

Bu pedagojik simplification. Gerçek FlashAttention:
  • Forward + backward kernel
  • Causal masking
  • Variable-length sequences
  • Dropout
  • Bias support
  • ~2000 line Triton code

Performance#

Gerçek FlashAttention vs naive PyTorch attention:
  • 4-10x faster (sequence length'e göre)
  • O(N) memory (vs O(N²))
  • Numerically identical (no approximation)

LLM impact#

FlashAttention-2 ile Llama 3 70B fine-tuning 2-3x daha hızlı. KV cache + FlashAttention inference'ta da kritik (Modül 33).

9. PyTorch Entegrasyonu#

Triton kernel'i
torch.autograd.Function
ile sarmak (Modül 2.6'da gördük):
class TritonSoftmax(torch.autograd.Function): @staticmethod def forward(ctx, x): out = torch.empty_like(x) n_rows, n_cols = x.shape BLOCK_SIZE = triton.next_power_of_2(n_cols) softmax_kernel[(n_rows,)]( out, x, x.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE ) ctx.save_for_backward(out) return out @staticmethod def backward(ctx, grad_output): out, = ctx.saved_tensors # Softmax backward: grad = out * (grad_output - sum(grad_output * out)) # Bu da Triton kernel olabilir grad_input = ... # custom backward kernel return grad_input # Module class TritonSoftmaxLayer(torch.nn.Module): def forward(self, x): return TritonSoftmax.apply(x)

torch.library.custom_op (modern)#

PyTorch 2.4+:
@torch.library.custom_op("mylib::triton_softmax", mutates_args=()) def triton_softmax(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) # ... Triton kernel ... return out @triton_softmax.register_fake def _(x): return torch.empty_like(x) @triton_softmax.register_autograd def _backward(ctx, grad_output): # ... return grad_input
torch.compile
ile fully composable custom op.

10. Benchmark ve Profiling#

Triton'un kendi benchmark utility'si#

@triton.testing.perf_report( triton.testing.Benchmark( x_names=["N"], x_vals=[2**i for i in range(8, 14)], line_arg="provider", line_vals=["triton", "pytorch"], line_names=["Triton", "PyTorch"], styles=[("blue", "-"), ("green", "-")], ylabel="GB/s", plot_name="softmax-perf", args={}, ) ) def benchmark(N, provider): x = torch.randn(1024, N, device="cuda", dtype=torch.float32) if provider == "triton": ms = triton.testing.do_bench(lambda: softmax(x)) elif provider == "pytorch": ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1)) gbps = 2 * x.numel() * x.element_size() / ms / 1e6 return gbps benchmark.run(show_plots=True, print_data=True)
Otomatik grafik çıkarır, performans karşılaştırması yapar.

Manuel timing#

import time # Warmup for _ in range(10): triton_kernel(...) torch.cuda.synchronize() # Timing start = time.perf_counter() for _ in range(100): triton_kernel(...) torch.cuda.synchronize() elapsed = (time.perf_counter() - start) / 100 * 1000 print(f"Triton: {elapsed:.3f} ms")

Profiling#

Nsight Compute Triton kernel'ları detaylı profile eder:
  • Occupancy
  • Memory throughput
  • Roofline analysis (compute-bound vs memory-bound)
Modül 37'de detayda.

11. Production Patterns#

1. Custom kernel ne zaman yazılır?#

  • Profil'de bir kernel %30+ time alıyorsa
  • PyTorch native + torch.compile yetmiyorsa
  • Algorithmic innovation (FlashAttention gibi)
  • Hardware-specific (Tensor Core, FP8)
%95'inde torch.compile yeter.

2. Caching kernel artifact#

import os os.environ["TRITON_CACHE_DIR"] = "/path/to/cache"
Compiled kernels disk'te cache'lenir. İkinci deployment hızlı boot.

3. Versioning#

Triton kernel'lar API değişimi ile bozulabilir. Production'da:
  • Triton version pin:
    triton==3.0.0
  • Backward compatibility test
  • Regression suite

4. Numerical verification#

Custom kernel için mutlaka PyTorch reference ile karşılaştırma:
def verify(triton_fn, pytorch_fn, x): out_triton = triton_fn(x) out_pytorch = pytorch_fn(x) rel_err = (out_triton - out_pytorch).abs().max() / out_pytorch.abs().max() assert rel_err < 1e-4, f"Mismatch: {rel_err}"

5. Edge cases#

Custom kernel'lar edge case'lerde çuvallar:
  • Empty input (N=0)
  • Single element
  • Non-contiguous tensors
  • Mixed dtype
Modül 37 production-grade pattern'leri derinleştiriyor.

12. Modül 37'ye Köprü#

Bu ders Triton'a giriş. Modül 37 (Custom CUDA ve Triton Kernels) çok daha derinleştiriyor:
  • FlashAttention v3 detaylı walkthrough
  • Quantization kernel'ları (Q4, AWQ, GPTQ)
  • Tensor Core matrix ops
  • MoE routing kernels
  • Custom backward implementation
  • Multi-GPU kernel patterns (cross-device synchronization)
  • Production debugging (Nsight Compute)
  • Performance modeling

Ön hazırlık checklist#

Modül 37'ye geçmeden önce:
  • Vector add Triton kernel yazıp PyTorch ile karşılaştırma
  • Softmax kernel'ı verify et
  • Naive matmul kernel + autotune
  • FlashAttention mini implementation deneme
  • Triton documentation skim (
    triton-lang.org
    )
  • OpenAI Triton tutorials (5-10 tutorial)

Pratik öneri#

Production projende Triton kernel yazmadan önce:
  1. PyTorch native + torch.compile dene
  2. Profile et
  3. Bottleneck'i identify et
  4. Sadece bottleneck custom yaz
%80 case'de bu adımlar custom kernel'a gerek bırakmaz.

13. Mini Egzersizler#

  1. Vector add benchmark: 10M element add, PyTorch vs Triton. Kim hızlı?
  2. Softmax bandwidth: 1024 row × 4096 col softmax, Triton kernel ile theoretical max GB/s'in % kaçı?
  3. Matmul size sweet spot: Hangi M, N, K boyutu için Triton custom matmul cuBLAS'a yaklaşır?
  4. FlashAttention vs vanilla: 8K sequence length attention. Memory ve compute farkı?
  5. Production karar: Llama 3 fine-tuning'de custom Triton kernel yazmaya değer mi?

Bu Derste Neler Öğrendik?#

Triton — Python-syntax GPU kernel language ✓ GPU programming model: grid, block, warp, thread hierarchy ✓ İlk kernel: vector add —
@triton.jit
pattern ✓ Memory coalescing — performance fundamental ✓ Softmax kernel sıfırdan — fused load-compute-store ✓ Matmul tiling matematiği — block-wise paralelism ✓ Autotune — block size, num_warps optimal search ✓ FlashAttention mini implementation — block-wise online softmax ✓ PyTorch entegrasyonu — autograd.Function + torch.library ✓ Benchmark + profiling — Triton testing utility ✓ Production patterns — caching, versioning, verification ✓ Modül 37'ye köprü — derin dalış için ön hazırlık

Sıradaki Ders#

5.6 —
torch.distributed
Derinleştirilmiş: DDP, FSDP, ZeRO Stages
5.4'te NCCL temellerini gördük. Şimdi production distributed training: DDP gradient bucketing detay, FSDP shard strategies (FULL_SHARD, SHARD_GRAD_OP, HYBRID), DeepSpeed ZeRO Stage 1/2/3 karşılaştırma. Modül 17 (Distributed Training) için son köprü.

Frequently Asked Questions

Three different use cases. **Triton**: for Python developers, 80-95% of hand-tuned CUDA performance, fast iteration. **CUDA C++**: highest performance, custom hardware features (Tensor Core, TMA), 10x more dev time. **CUTLASS** (NVIDIA library): matmul-specific template library, expert-level CUDA primitives. **Practical**: start with Triton. Compatible with PyTorch ecosystem. CUDA C++ only for extreme cases (FlashAttention v3 production). CUTLASS for NVIDIA internals.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

Related Content