Pratik Lab: 50 Satırda KV Cache Implementasyonu
Teorinin sonunda kod var. PyTorch ile minik bir transformer attention layer'ı yazacağız, KV cache'i açık/kapalı toggle edip 10× hızlanma göreceğiz. Bu lab, transformer'ı 'içeriden' anlamanın kestirme yolu.
Şükrü Yusuf KAYA
20 min read
AdvancedLab #2: KV Cache'i Sıfırdan Yaz
Bu lab'in amacı: KV cache mekanizmasının gerçekten ne yaptığını kod yazarak içselleştirmek. Production'da bu kodu kullanmayacağız — bu öğrenme aracı.
Bitirince elinde:
- 50 satır PyTorch'la minik bir multi-head attention layer
- KV cache'i ile toggle edebilen bir generate fonksiyonu
use_cache=True/False - Aynı output'u garantileyen, ama 10× hızlı çalışan bir versiyon
Lokal Kurulum Tavsiyesi
Bu lab tarayıcıda çalışmayabilir çünkü PyTorch CPU bile büyük. Lokal kuruluma geç, ya da Google Colab'ı kullan. Bu kod ~1-2 dakikada biter.
Adım 1 — Minimal Multi-Head Attention#
İlk önce attention'ı yaz, cache yok. Sonra adım adım cache'i ekleyeceğiz.
python
import torchimport torch.nn as nnimport torch.nn.functional as F class MiniAttention(nn.Module): """Minik causal multi-head attention. n_heads, d_head=8 sabit.""" def __init__(self, d_model=64, n_heads=4): super().__init__() assert d_model % n_heads == 0 self.n_heads = n_heads self.d_head = d_model // n_heads self.qkv = nn.Linear(d_model, 3 * d_model) self.out = nn.Linear(d_model, d_model) def forward(self, x, kv_cache=None): """ x: (batch, seq, d_model) kv_cache: None VEYA (K_prev, V_prev) tuple — (batch, n_heads, prev_seq, d_head) Returns: (output, new_kv_cache) """ B, T, C = x.shape H, D = self.n_heads, self.d_head # Q, K, V projeksiyonu qkv = self.qkv(x) # (B, T, 3C) q, k, v = qkv.split(C, dim=-1) q = q.view(B, T, H, D).transpose(1, 2) # (B, H, T, D) k = k.view(B, T, H, D).transpose(1, 2) v = v.view(B, T, H, D).transpose(1, 2) # ⭐ KV CACHE LOGIC ⭐ if kv_cache is not None: k_prev, v_prev = kv_cache k = torch.cat([k_prev, k], dim=2) # geçmiş K ile birleştir v = torch.cat([v_prev, v], dim=2) # Yeni cache: birleşmiş K, V new_cache = (k, v) # Attention skorları att = (q @ k.transpose(-2, -1)) / (D ** 0.5) # Causal mask: q sadece kendi pozisyonuna kadar bakabilir q_len = q.shape[2] k_len = k.shape[2] mask = torch.tril(torch.ones(q_len, k_len)).bool() # Eğer cache'ten geliyorsa, mask'i offset'le if k_len > q_len: mask = torch.cat( [torch.ones(q_len, k_len - q_len).bool(), mask], dim=1 ) att = att.masked_fill(~mask, float("-inf")) att = F.softmax(att, dim=-1) # Output out = att @ v out = out.transpose(1, 2).contiguous().view(B, T, C) return self.out(out), new_cacheMiniAttention — KV cache'li attention layer (50 satır)
Adım 2 — Generate Fonksiyonu: Cache OFF vs ON#
python
def generate_no_cache(attn, prompt, max_new_tokens=20): """Cache YOK — her step'te tüm sequence'i yeniden işle.""" seq = prompt.clone() for _ in range(max_new_tokens): out, _ = attn(seq, kv_cache=None) # cache yok, tüm seq baştan next_logit = out[:, -1:, :] next_tok = next_logit.argmax(dim=-1, keepdim=False) # Fake "vocab → embedding" — basitlik için son output'u doğrudan ekle seq = torch.cat([seq, next_logit], dim=1) return seq def generate_with_cache(attn, prompt, max_new_tokens=20): """Cache AÇIK — sadece yeni tokenı işle.""" # Adım 1: prompt'u prefill et, cache oluştur out, cache = attn(prompt, kv_cache=None) seq = prompt.clone() next_input = out[:, -1:, :] # son token for _ in range(max_new_tokens): out, cache = attn(next_input, kv_cache=cache) # ⭐ sadece 1 token next_input = out[:, -1:, :] seq = torch.cat([seq, next_input], dim=1) return seqİki versiyonu yan yana koy — cache'siz vs cache'li
python
import torchimport time # Yukarıdaki MiniAttention sınıfını burada da var sayalım# (Tam kod yukarıda) torch.manual_seed(42)attn = MiniAttention(d_model=64, n_heads=4).eval() # 100 tokenlı bir prompt simüle etprompt = torch.randn(1, 100, 64) # Cache OFFstart = time.perf_counter()with torch.no_grad(): _ = generate_no_cache(attn, prompt, max_new_tokens=50)t_off = time.perf_counter() - start # Cache ONstart = time.perf_counter()with torch.no_grad(): _ = generate_with_cache(attn, prompt, max_new_tokens=50)t_on = time.perf_counter() - start print(f"Cache OFF: {t_off*1000:.1f} ms")print(f"Cache ON: {t_on*1000:.1f} ms")print(f"Hızlanma: {t_off/t_on:.1f}×")Cache OFF vs ON: hız karşılaştırması
Sayı Gerçek
~10× hızlanma — sadece 50 satırlık bir attention layer'ında. Production transformer'larda (40+ layer) bu kazanç çok daha büyük çünkü her layer için KV cache var.
Adım 3 — Output Eşdeğerliğini Doğrula#
Ders 10'da kanıtladığımız "cache lossless" iddiasını test edelim:
python
import torch torch.manual_seed(42)attn = MiniAttention(d_model=64, n_heads=4).eval() prompt = torch.randn(1, 10, 64) with torch.no_grad(): out_no_cache = generate_no_cache(attn, prompt, max_new_tokens=5) out_with_cache = generate_with_cache(attn, prompt, max_new_tokens=5) # Bit-identik mi?diff = (out_no_cache - out_with_cache).abs().max().item()print(f"Max difference: {diff:.2e}")print(f"Are they identical? {torch.allclose(out_no_cache, out_with_cache, atol=1e-5)}")Cache açık vs kapalı: output identik olmalı
Fark < 5×10⁻⁷ → floating point noise. Matematiksel olarak özdeş, sadece floating point precision kaybı (transpose order gibi). Yani caching gerçekten lossless.
Çıkarılacak Dersler#
Bu lab'i yaptıysan, şu kavramlar artık parmaklarında:
- KV cache, attention layer'ın metodunda 2 satırla eklenir (
forward)torch.cat - Generate döngüsü iki faza ayrılır: prefill (full sequence) + decode (token by token)
- Cache'li versiyon ~10× hızlı, output identik
- Causal mask'in offset'i — eğer cache'ten geliyorsan, mask'i kaydırman lazım
Devamı Modül 10'da
Modül 10'da vLLM ve SGLang'ın bu konsepti multi-tenant ortamda nasıl uyguladığını göreceğiz — PagedAttention, RadixAttention, prefix tree. Bu lab'in mantığı orada da geçerli.
✓ Pekiştir#
Bir Sonraki Derste#
vLLM ve SGLang'ın PagedAttention trick'ini sezgisel olarak anlayacağız. Multi-tenant production'da KV cache'i nasıl verimli yönetiyorlar?
Frequently Asked Questions
Hayır, bu öğrenme amaçlı simplified bir versiyon. Production için vLLM, SGLang, TensorRT-LLM gibi optimize edilmiş kütüphaneler kullan. Bu lab'i 'transformer'ı içeriden anladım' kanıtı say.
Yorumlar & Soru-Cevap
(0)Yorum yazmak için giriş yap.
Yorumlar yükleniyor...
Related Content
1. Temeller — Context Penceresi Ekonomisi
Bu Eğitim Hakkında ve Prompt Caching Neden Önemli?
Start Learning1. Temeller — Context Penceresi Ekonomisi
Token Ekonomisi 101: Input vs Output Cost Asimetrisi
Start Learning1. Temeller — Context Penceresi Ekonomisi