Custom Triton Kernel Lab: Cross-Entropy + Ignore-Index — Unsloth-Style Speedup
PyTorch's native `F.cross_entropy(ignore_index=-100)` is one of LLM training's most-called kernels, and a fused Triton implementation can beat the naïve path by roughly 30%. In this Cookbook lab we fuse logits + softmax + cross-entropy + gradient into a single kernel — the same pattern Unsloth uses. On an RTX 4090 it lifted 8B-model fine-tuning throughput by 15%.
Şükrü Yusuf KAYA
38 min read
Advanced · Python
```python
# === Custom Fused Cross-Entropy Triton Kernel ===
# Inspired by Unsloth + Liger Kernel
import torch
import triton
import triton.language as tl


@triton.jit
def fused_ce_kernel(
    logits_ptr, labels_ptr, loss_ptr, grad_ptr,
    n_valid, V,  # n_valid = count of non-ignored tokens, V = vocab_size
    IGNORE_INDEX: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    row_idx = tl.program_id(0)
    label = tl.load(labels_ptr + row_idx)

    # If label == IGNORE_INDEX: zero the loss and skip (grad buffer is pre-zeroed)
    if label == IGNORE_INDEX:
        tl.store(loss_ptr + row_idx, 0.0)
        return

    # Pointer to this row's logits
    logits_row = logits_ptr + row_idx * V

    # Online softmax: streaming max + rescaled sum over BLOCK_V-sized chunks
    col_offsets = tl.arange(0, BLOCK_V)
    row_max = -float("inf")
    row_sum = 0.0
    for v_block_start in range(0, V, BLOCK_V):
        block_offsets = v_block_start + col_offsets
        mask = block_offsets < V
        block_logits = tl.load(logits_row + block_offsets, mask=mask, other=-float("inf"))
        block_max = tl.max(block_logits, axis=0)
        new_max = tl.maximum(row_max, block_max)
        # Rescale running sum to the new reference maximum
        row_sum = row_sum * tl.exp(row_max - new_max)
        row_sum += tl.sum(tl.exp(block_logits - new_max), axis=0)
        row_max = new_max

    # Loss: log_sum_exp - logits[label]
    logit_label = tl.load(logits_row + label)
    log_z = tl.log(row_sum) + row_max
    loss = log_z - logit_label
    tl.store(loss_ptr + row_idx, loss)

    # Gradient: (softmax(logits) - one_hot(label)) / n_valid
    # Dividing by the valid-token count keeps the gradient consistent
    # with the masked mean computed in the wrapper.
    for v_block_start in range(0, V, BLOCK_V):
        block_offsets = v_block_start + col_offsets
        mask = block_offsets < V
        block_logits = tl.load(logits_row + block_offsets, mask=mask, other=-float("inf"))
        block_softmax = tl.exp(block_logits - log_z)
        # Subtract the one-hot at the label position
        is_label = block_offsets == label
        grad = (block_softmax - tl.where(is_label, 1.0, 0.0)) / n_valid
        tl.store(grad_ptr + row_idx * V + block_offsets, grad, mask=mask)


# Python wrapper
def fused_ce(logits, labels, ignore_index=-100):
    n, V = logits.shape  # n = batch * seq
    loss = torch.empty(n, device=logits.device, dtype=torch.float32)
    grad = torch.zeros_like(logits)  # ignored rows keep zero gradient
    valid_mask = labels != ignore_index
    n_valid = int(valid_mask.sum())
    BLOCK_V = 1024
    fused_ce_kernel[(n,)](
        logits, labels, loss, grad, n_valid, V,
        IGNORE_INDEX=ignore_index, BLOCK_V=BLOCK_V,
    )
    mean_loss = loss[valid_mask].mean()
    return mean_loss, grad


# Bench (RTX 4090, batch=8, seq=4096, vocab=128K):
# torch.nn.functional.cross_entropy: 12.4 ms
# Custom Triton fused CE: 8.7 ms (-30%)
```

Custom Fused Cross-Entropy Triton kernel
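The kernel's first pass hinges on the online (streaming) log-sum-exp: whenever a new block raises the running maximum, the running sum is rescaled by `exp(old_max - new_max)` before the block's contributions are added. A minimal pure-Python sketch of that rescaling (the block size and logit values here are illustrative, not from the lab):

```python
import math

def streaming_logsumexp(logits, block_size):
    """Compute log(sum(exp(x))) one block at a time, as the Triton kernel does."""
    row_max = -math.inf
    row_sum = 0.0
    for start in range(0, len(logits), block_size):
        block = logits[start:start + block_size]
        block_max = max(block)
        new_max = max(row_max, block_max)
        # Rescale the running sum to the new reference maximum
        row_sum = row_sum * math.exp(row_max - new_max)
        row_sum += sum(math.exp(x - new_max) for x in block)
        row_max = new_max
    return math.log(row_sum) + row_max

# Streaming result matches the direct (one-shot) computation
logits = [2.0, -1.0, 0.5, 3.0, -2.0, 1.5, 0.0]
direct = math.log(sum(math.exp(x) for x in logits))
assert abs(streaming_logsumexp(logits, block_size=3) - direct) < 1e-9
```

Because every `exp` is taken relative to the running maximum, no intermediate value overflows even when individual logits are large — the same reason the kernel never materializes the full softmax during the loss pass.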
✅ Deliverables
1) Run this kernel on your RTX 4090 and benchmark it against torch CE.
2) Verify correctness with a numeric bit-exact test.
3) Next lesson: 13.4 — Liger Kernel Tour.
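The numeric check in step 2 rests on the identity the kernel's second pass encodes: for CE loss `log Z - z[label]`, the gradient is `softmax(z)_j - 1{j == label}`. A stdlib-only finite-difference sketch of that identity (tiny vocab, illustrative values):

```python
import math

def ce_loss(z, label):
    """Cross-entropy for one row: log_sum_exp(z) - z[label]."""
    log_z = math.log(sum(math.exp(x) for x in z))
    return log_z - z[label]

def ce_grad(z, label):
    """Analytic gradient: softmax(z) - one_hot(label)."""
    Z = sum(math.exp(x) for x in z)
    return [math.exp(x) / Z - (1.0 if j == label else 0.0) for j, x in enumerate(z)]

# Central finite differences agree with the analytic gradient
z, label, eps = [1.0, -0.5, 2.0, 0.3], 2, 1e-6
for j in range(len(z)):
    zp, zm = z[:], z[:]
    zp[j] += eps
    zm[j] -= eps
    fd = (ce_loss(zp, label) - ce_loss(zm, label)) / (2 * eps)
    assert abs(fd - ce_grad(z, label)[j]) < 1e-6
```

On the GPU the same idea applies row by row: compare the fused kernel's `grad` buffer against `softmax(logits) - one_hot(labels)` (divided by the valid-token count) computed with plain torch ops, with ignored rows expected to stay all-zero.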