Skip to content

Pratik Lab: 50 Satırda KV Cache Implementasyonu

Teorinin sonunda kod var. PyTorch ile minik bir transformer attention layer'ı yazacağız, KV cache'i açık/kapalı toggle edip 10× hızlanma göreceğiz. Bu lab, transformer'ı 'içeriden' anlamanın kestirme yolu.

Şükrü Yusuf KAYA
20 min read
Advanced

Lab #2: KV Cache'i Sıfırdan Yaz

Bu lab'in amacı: KV cache mekanizmasının gerçekten ne yaptığını kod yazarak içselleştirmek. Production'da bu kodu kullanmayacağız — bu öğrenme aracı.
Bitirince elinde:
  • 50 satır PyTorch'la minik bir multi-head attention layer
  • KV cache'i
    use_cache=True/False
    ile toggle edebilen bir generate fonksiyonu
  • Aynı output'u garantileyen, ama 10× hızlı çalışan bir versiyon
Lokal Kurulum Tavsiyesi
Bu lab tarayıcıda çalışmayabilir çünkü PyTorch CPU bile büyük. Lokal kuruluma geç, ya da Google Colab'ı kullan. Bu kod ~1-2 dakikada biter.

Adım 1 — Minimal Multi-Head Attention#

İlk önce attention'ı yaz, cache yok. Sonra adım adım cache'i ekleyeceğiz.
python
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class MiniAttention(nn.Module):
"""Minik causal multi-head attention. n_heads, d_head=8 sabit."""
 
def __init__(self, d_model=64, n_heads=4):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.qkv = nn.Linear(d_model, 3 * d_model)
self.out = nn.Linear(d_model, d_model)
 
def forward(self, x, kv_cache=None):
"""
x: (batch, seq, d_model)
kv_cache: None VEYA (K_prev, V_prev) tuple — (batch, n_heads, prev_seq, d_head)
Returns: (output, new_kv_cache)
"""
B, T, C = x.shape
H, D = self.n_heads, self.d_head
 
# Q, K, V projeksiyonu
qkv = self.qkv(x) # (B, T, 3C)
q, k, v = qkv.split(C, dim=-1)
q = q.view(B, T, H, D).transpose(1, 2) # (B, H, T, D)
k = k.view(B, T, H, D).transpose(1, 2)
v = v.view(B, T, H, D).transpose(1, 2)
 
# ⭐ KV CACHE LOGIC ⭐
if kv_cache is not None:
k_prev, v_prev = kv_cache
k = torch.cat([k_prev, k], dim=2) # geçmiş K ile birleştir
v = torch.cat([v_prev, v], dim=2)
 
# Yeni cache: birleşmiş K, V
new_cache = (k, v)
 
# Attention skorları
att = (q @ k.transpose(-2, -1)) / (D ** 0.5)
# Causal mask: q sadece kendi pozisyonuna kadar bakabilir
q_len = q.shape[2]
k_len = k.shape[2]
mask = torch.tril(torch.ones(q_len, k_len)).bool()
# Eğer cache'ten geliyorsa, mask'i offset'le
if k_len > q_len:
mask = torch.cat(
[torch.ones(q_len, k_len - q_len).bool(), mask], dim=1
)
att = att.masked_fill(~mask, float("-inf"))
att = F.softmax(att, dim=-1)
 
# Output
out = att @ v
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.out(out), new_cache
MiniAttention — KV cache'li attention layer (50 satır)

Adım 2 — Generate Fonksiyonu: Cache OFF vs ON#

python
def generate_no_cache(attn, prompt, max_new_tokens=20):
"""Cache YOK — her step'te tüm sequence'i yeniden işle."""
seq = prompt.clone()
for _ in range(max_new_tokens):
out, _ = attn(seq, kv_cache=None) # cache yok, tüm seq baştan
next_logit = out[:, -1:, :]
next_tok = next_logit.argmax(dim=-1, keepdim=False)
# Fake "vocab → embedding" — basitlik için son output'u doğrudan ekle
seq = torch.cat([seq, next_logit], dim=1)
return seq
 
def generate_with_cache(attn, prompt, max_new_tokens=20):
"""Cache AÇIK — sadece yeni tokenı işle."""
# Adım 1: prompt'u prefill et, cache oluştur
out, cache = attn(prompt, kv_cache=None)
 
seq = prompt.clone()
next_input = out[:, -1:, :] # son token
 
for _ in range(max_new_tokens):
out, cache = attn(next_input, kv_cache=cache) # ⭐ sadece 1 token
next_input = out[:, -1:, :]
seq = torch.cat([seq, next_input], dim=1)
return seq
İki versiyonu yan yana koy — cache'siz vs cache'li
python
import torch
import time
 
# Yukarıdaki MiniAttention sınıfını burada da var sayalım
# (Tam kod yukarıda)
 
torch.manual_seed(42)
attn = MiniAttention(d_model=64, n_heads=4).eval()
 
# 100 tokenlı bir prompt simüle et
prompt = torch.randn(1, 100, 64)
 
# Cache OFF
start = time.perf_counter()
with torch.no_grad():
_ = generate_no_cache(attn, prompt, max_new_tokens=50)
t_off = time.perf_counter() - start
 
# Cache ON
start = time.perf_counter()
with torch.no_grad():
_ = generate_with_cache(attn, prompt, max_new_tokens=50)
t_on = time.perf_counter() - start
 
print(f"Cache OFF: {t_off*1000:.1f} ms")
print(f"Cache ON: {t_on*1000:.1f} ms")
print(f"Hızlanma: {t_off/t_on:.1f}×")
Cache OFF vs ON: hız karşılaştırması
Sayı Gerçek
~10× hızlanma — sadece 50 satırlık bir attention layer'ında. Production transformer'larda (40+ layer) bu kazanç çok daha büyük çünkü her layer için KV cache var.

Adım 3 — Output Eşdeğerliğini Doğrula#

Ders 10'da kanıtladığımız "cache lossless" iddiasını test edelim:
python
import torch
 
torch.manual_seed(42)
attn = MiniAttention(d_model=64, n_heads=4).eval()
 
prompt = torch.randn(1, 10, 64)
 
with torch.no_grad():
out_no_cache = generate_no_cache(attn, prompt, max_new_tokens=5)
out_with_cache = generate_with_cache(attn, prompt, max_new_tokens=5)
 
# Bit-identik mi?
diff = (out_no_cache - out_with_cache).abs().max().item()
print(f"Max difference: {diff:.2e}")
print(f"Are they identical? {torch.allclose(out_no_cache, out_with_cache, atol=1e-5)}")
Cache açık vs kapalı: output identik olmalı
Fark < 5×10⁻⁷ → floating point noise. Matematiksel olarak özdeş, sadece floating point precision kaybı (transpose order gibi). Yani caching gerçekten lossless.

Çıkarılacak Dersler#

Bu lab'i yaptıysan, şu kavramlar artık parmaklarında:
  1. KV cache, attention layer'ın
    forward
    metodunda 2 satırla eklenir
    (
    torch.cat
    )
  2. Generate döngüsü iki faza ayrılır: prefill (full sequence) + decode (token by token)
  3. Cache'li versiyon ~10× hızlı, output identik
  4. Causal mask'in offset'i — eğer cache'ten geliyorsan, mask'i kaydırman lazım
Devamı Modül 10'da
Modül 10'da vLLM ve SGLang'ın bu konsepti multi-tenant ortamda nasıl uyguladığını göreceğiz — PagedAttention, RadixAttention, prefix tree. Bu lab'in mantığı orada da geçerli.

✓ Pekiştir#

Bir Sonraki Derste#

vLLM ve SGLang'ın PagedAttention trick'ini sezgisel olarak anlayacağız. Multi-tenant production'da KV cache'i nasıl verimli yönetiyorlar?

Frequently Asked Questions

Hayır, bu öğrenme amaçlı simplified bir versiyon. Production için vLLM, SGLang, TensorRT-LLM gibi optimize edilmiş kütüphaneler kullan. Bu lab'i 'transformer'ı içeriden anladım' kanıtı say.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

Related Content