
Triton Crash Course: Block Pointer + Autotune + Masks — 50 Satırda GPU Kernel

Triton (OpenAI, 2021) is a GPU kernel framework that aims to be as fast as CUDA and as easy as Python. This lesson covers `@triton.jit`, `tl.program_id`, `tl.arange`, block pointer arithmetic, the autotune decorator, mask-based load/store, and Triton's shared-memory abstraction. You will write Triton vector add → matmul → softmax kernels from scratch on an RTX 4090.

Şükrü Yusuf KAYA
36 minute read
Advanced
```python
# === Triton Vector Add — "Hello World" ===
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
               BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail block

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    out = x + y
    tl.store(out_ptr + offsets, out, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out

# Test
x = torch.randn(10_000_000, device="cuda")
y = torch.randn(10_000_000, device="cuda")
torch.testing.assert_close(add(x, y), x + y)
print("✅ Triton vector add OK")
```
Triton — Vector Add 'Hello World'
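The title promises the autotune decorator, but the kernel above hard-codes `BLOCK_SIZE=1024`. A minimal sketch of how the same kernel could pick its block size automatically with `@triton.autotune` — the config values and the `add_kernel_tuned`/`add_tuned` names here are illustrative, not tuned for any particular GPU:

```python
import torch
import triton
import triton.language as tl

# Candidate configs; Triton benchmarks each on first launch and caches the winner.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 4096}),
    ],
    key=["n_elements"],  # re-tune when the problem size changes
)
@triton.jit
def add_kernel_tuned(x_ptr, y_ptr, out_ptr, n_elements,
                     BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add_tuned(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    # BLOCK_SIZE is no longer passed: the autotuner injects the winning config
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    add_kernel_tuned[grid](x, y, out, n)
    return out
```

Note that `BLOCK_SIZE` disappears from the call site; the autotuner owns it, and the first call pays a one-time benchmarking cost per distinct `n_elements`.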
```python
# === Triton Softmax Kernel (row-wise, numerically stable) ===
@triton.jit
def softmax_kernel(out_ptr, in_ptr, in_row_stride, out_row_stride,
                   n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    row_start = in_ptr + row_idx * in_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(row_start + col_offsets, mask=mask, other=-float("inf"))

    # Numerically stable softmax: subtract the row max before exp
    row_max = tl.max(row, axis=0)
    row_minus_max = row - row_max
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    out = numerator / denominator

    out_start = out_ptr + row_idx * out_row_stride
    tl.store(out_start + col_offsets, out, mask=mask)

def softmax(x):
    out = torch.empty_like(x)
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    softmax_kernel[(n_rows,)](
        out, x, x.stride(0), out.stride(0), n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out

# Bench (RTX 4090)
x = torch.randn(1024, 32_000, device="cuda")
# Triton softmax: 0.18 ms
# torch.softmax: 0.21 ms (close, with a slight edge to Triton)
```
Triton softmax kernel
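The remaining piece from the title is block pointer arithmetic, which neither kernel above uses. A hedged sketch of a tiled matmul built on `tl.make_block_ptr` / `tl.advance`, the block-pointer API for describing 2D tiles; the kernel name, tile sizes, and helper `matmul` here are my own, not a tuned implementation:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Block pointers describe a (BLOCK_M, BLOCK_K) tile view over the full matrix;
    # boundary_check + padding replace manual masks at the edges.
    a_block = tl.make_block_ptr(
        base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block = tl.make_block_ptr(
        base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_block, boundary_check=(0, 1), padding_option="zero")
        b = tl.load(b_block, boundary_check=(0, 1), padding_option="zero")
        acc += tl.dot(a, b)
        # Slide both tiles one step along the K dimension
        a_block = tl.advance(a_block, (0, BLOCK_K))
        b_block = tl.advance(b_block, (BLOCK_K, 0))

    c_block = tl.make_block_ptr(
        base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_block, acc, boundary_check=(0, 1))

def matmul(a, b):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
    matmul_kernel[grid](a, b, c, M, N, K,
                        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                        c.stride(0), c.stride(1),
                        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)
    return c
```

The design choice worth noting: with block pointers the kernel never computes per-element offsets or masks by hand; `boundary_check` handles the ragged edges, and `tl.advance` moves the tile.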
✅ Deliverables
  1. Run the two kernels above on your RTX 4090.
  2. Compare them against the torch baselines (`x + y`, `torch.softmax`).
  3. Next lesson: 13.3 — Custom Triton Kernel Lab.
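To reproduce the timing comparison on your own card rather than trusting the numbers quoted above, a small harness built on `triton.testing.do_bench` can help; `bench_softmax` is a hypothetical helper name, and results will vary by GPU:

```python
import torch
import triton

def bench_softmax(softmax_fn, n_rows=1024, n_cols=32_000):
    x = torch.randn(n_rows, n_cols, device="cuda")
    # do_bench warms up, runs repeated timed iterations, and returns a runtime in ms
    ms_triton = triton.testing.do_bench(lambda: softmax_fn(x))
    ms_torch = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
    print(f"Triton: {ms_triton:.3f} ms | torch: {ms_torch:.3f} ms")
    return ms_triton, ms_torch
```

Pass in the `softmax` function from the kernel above; `do_bench` takes care of CUDA synchronization, so no manual `torch.cuda.synchronize()` calls are needed.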
