Skip to content

Custom Triton Kernel Lab: Cross-Entropy + Ignore-Index — Unsloth-Style Speedup

PyTorch's native \`F.cross_entropy(ignore_index=-100)\` is one of LLM training's most-called kernels. A naïve implementation can be made ~30% faster with Triton. Cookbook Lab: fused logits + softmax + CE + grad → a single kernel. This is the pattern Unsloth uses. 8B-model fine-tuning throughput +15% on an RTX 4090.

Şükrü Yusuf KAYA
38 min read
Advanced
Custom Triton Kernel Lab: Cross-Entropy + Ignore-Index — Unsloth-Style Speedup
python
# === Custom Fused Cross-Entropy Triton Kernel ===
# Inspired by Unsloth + Liger Kernel
import torch, triton, triton.language as tl
 
@triton.jit
def fused_ce_kernel(
    logits_ptr, labels_ptr, loss_ptr, grad_ptr,
    n, V, # n = batch*seq, V = vocab_size
    IGNORE_INDEX: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Fused cross-entropy forward + gradient for one row of logits.

    One Triton program handles one row: it streams over the vocabulary in
    BLOCK_V-sized chunks, computing the softmax normalizer online (pass 1)
    and the gradient softmax - one_hot(label), scaled by 1/n (pass 2), so
    full log-probs are never materialized.
    """
    row_idx = tl.program_id(0)
    label = tl.load(labels_ptr + row_idx)

    # If label == IGNORE_INDEX: store a zero loss and skip this row.
    # NOTE(review): the gradient row is NOT written on this path — the caller
    # must pass a zero-initialized grad buffer (the Python wrapper below uses
    # torch.zeros_like); confirm all call sites honor that contract.
    if label == IGNORE_INDEX:
        tl.store(loss_ptr + row_idx, 0.0)
        return

    # Base pointer of this row's logits (row-major (n, V) layout assumed —
    # TODO confirm logits are contiguous).
    logits_row = logits_ptr + row_idx * V

    # Pass 1 — online softmax statistics (running max + rescaled running sum),
    # i.e. a streaming log-sum-exp over the V vocabulary entries.
    col_offsets = tl.arange(0, BLOCK_V)
    row_max = -float("inf")
    row_sum = 0.0

    for v_block_start in range(0, V, BLOCK_V):
        block_offsets = v_block_start + col_offsets
        mask = block_offsets < V
        # other=-inf makes out-of-range lanes contribute exp(-inf) == 0.
        block_logits = tl.load(logits_row + block_offsets, mask=mask, other=-float("inf"))

        block_max = tl.max(block_logits, axis=0)
        new_max = tl.maximum(row_max, block_max)

        # Rescale running sum to the new max before adding this block's terms.
        row_sum = row_sum * tl.exp(row_max - new_max)
        row_sum += tl.sum(tl.exp(block_logits - new_max), axis=0)
        row_max = new_max

    # Loss: log_sum_exp - logits[label]  (i.e. -log softmax at the label).
    logit_label = tl.load(logits_row + label)
    log_z = tl.log(row_sum) + row_max
    loss = log_z - logit_label
    tl.store(loss_ptr + row_idx, loss)

    # Pass 2 — gradient: (softmax(logits) - one_hot(label)) / n.
    # NOTE(review): dividing by n (total rows) assumes the loss is averaged
    # over ALL rows; the wrapper averages over non-ignored rows only, so the
    # two disagree whenever any label equals IGNORE_INDEX — confirm which
    # normalization is intended.
    for v_block_start in range(0, V, BLOCK_V):
        block_offsets = v_block_start + col_offsets
        mask = block_offsets < V
        block_logits = tl.load(logits_row + block_offsets, mask=mask, other=-float("inf"))
        # exp(logit - log_z) is the softmax probability for this block.
        block_softmax = tl.exp(block_logits - log_z)

        # Subtract the one-hot target at the label position.
        is_label = (block_offsets == label)
        grad = (block_softmax - tl.where(is_label, 1.0, 0.0)) / n

        tl.store(grad_ptr + row_idx * V + block_offsets, grad, mask=mask)
 
# Python wrapper
def fused_ce(logits, labels, ignore_index=-100):
    """Fused cross-entropy loss + gradient over a (n, V) logits matrix.

    Args:
        logits: (n, V) CUDA float tensor of unnormalized scores, n = batch*seq.
        labels: (n,) integer tensor; rows whose label == ignore_index are
            excluded from the loss and receive a zero gradient row.
        ignore_index: label value to skip (default -100, matching
            torch.nn.functional.cross_entropy).

    Returns:
        (mean_loss, grad): scalar mean loss over the non-ignored rows, and
        the (n, V) gradient of that mean loss w.r.t. logits.
    """
    n, V = logits.shape
    loss = torch.empty(n, device=logits.device)
    # Must be zero-initialized: the kernel early-returns on ignored rows
    # without writing their gradient.
    grad = torch.zeros_like(logits)
    BLOCK_V = 1024
    fused_ce_kernel[(n,)](
        logits, labels, loss, grad,
        n, V, IGNORE_INDEX=ignore_index, BLOCK_V=BLOCK_V,
    )
    valid_mask = labels != ignore_index
    n_valid = int(valid_mask.sum())
    if n_valid == 0:
        # Edge case: every label ignored. The original mean() over an empty
        # selection would return NaN; return a well-defined zero loss instead
        # (grad is already all zeros).
        return loss.new_zeros(()), grad
    # Bug fix: the kernel scales each gradient row by 1/n (all rows), but the
    # loss below is averaged over the n_valid non-ignored rows only. Rescale
    # so grad is exactly d(mean_loss)/d(logits).
    grad.mul_(n / n_valid)
    mean_loss = loss[valid_mask].mean()
    return mean_loss, grad
 
# Bench (RTX 4090, batch=8, seq=4096, vocab=128K):
# torch.nn.functional.cross_entropy: 12.4 ms
# Custom Triton fused CE: 8.7 ms (-30%)
Custom Fused Cross-Entropy Triton kernel
✅ Teslim
  1) Bu kernel'ı RTX 4090'ında çalıştır + torch CE ile karşılaştır. 2) Numeric bit-exact testle doğrula. 3) Sonraki ders: 13.4 — Liger Kernel Tour.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

Related Content