
Multi-Head Attention: N Parallel Heads, Concat + Projection, Grouped-Query Attention (GQA), Multi-Query Attention (MQA)

Why we split a single attention into N parallel heads: each head's capacity to learn a different pattern (syntactic, semantic, positional). The concat + output projection architecture, empirical head-pruning findings, grouped-query attention (GQA) in Llama-3 and Mistral, multi-query attention (MQA), and head visualization with Turkish examples.

Şükrü Yusuf KAYA
70-minute read
Advanced
🧠 Not one attention but many parallel attentions: each head learns a different pattern
A single attention works, but it looks from only one angle. In the phrase 'İstanbul'un başkenti', a single head has limited room to decide which token 'başkenti' should attend to. Multi-head attention's solution: run the same input through N different projections and compute N different attentions in parallel. Each head picks up a different pattern: one learns syntactic dependencies (subject-verb), another semantic similarity, another positional patterns. Llama-3 has 32 heads; GPT-4 (estimated) has 96. Seventy minutes from now you will have a deep grasp of the multi-head math, what heads empirically learn, and the modern GQA/MQA optimizations in the Llama-3 and Mistral architectures.

Lesson Map (10 Sections)#

  1. Why multi-head: the limits of a single head
  2. The math: split, parallel attention, concat
  3. Implementation: W_Q, W_K, W_V as a single big matrix
  4. Head pruning: empirical findings (Michel 2019, Voita 2019)
  5. Head specialization: syntactic, semantic, positional roles
  6. Memory cost: the KV cache total across N heads
  7. Multi-Query Attention (MQA): Shazeer 2019
  8. Grouped-Query Attention (GQA): Ainslie 2023, Llama-3
  9. GQA/MQA empirics: the quality vs. efficiency trade-off
  10. Head visualization with Turkish examples: practical understanding

2. Multi-Head Math#

2.1 Split#

Input x shape: [seq, d_model], with d_model = N heads × d_head.
Llama-3-8B: d_model = 4096, N = 32 → d_head = 128.

```
Q, K, V = x @ W_Q, x @ W_K, x @ W_V  # all [seq, d_model]
```

Reshape into heads:

```
Q = Q.view(seq, N, d_head)  # [seq, N, d_head]
K = K.view(seq, N, d_head)
V = V.view(seq, N, d_head)
```

2.2 Parallel attention per head#

Each head computes attention independently:

```
for h in range(N):
    head_h = attention(Q[:, h, :], K[:, h, :], V[:, h, :])  # [seq, d_head]
```

In practice this loop is replaced by a batched einsum/matmul.
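The per-head loop and the batched form are numerically equivalent. A minimal sketch with toy shapes (the `attention` helper below is illustrative, not a library function):

```python
import torch

torch.manual_seed(0)
seq, N, d_head = 6, 4, 8
Q = torch.randn(seq, N, d_head)
K = torch.randn(seq, N, d_head)
V = torch.randn(seq, N, d_head)

def attention(q, k, v):
    # q, k, v: [seq, d_head] for a single head
    w = torch.softmax(q @ k.T / d_head**0.5, dim=-1)
    return w @ v

# Per-head loop
loop_out = torch.stack([attention(Q[:, h], K[:, h], V[:, h]) for h in range(N)], dim=1)

# Batched: move heads to a leading dim, one matmul per operation
Qh, Kh, Vh = (t.transpose(0, 1) for t in (Q, K, V))           # [N, seq, d_head]
scores = Qh @ Kh.transpose(-2, -1) / d_head**0.5              # [N, seq, seq]
batched_out = (torch.softmax(scores, dim=-1) @ Vh).transpose(0, 1)  # [seq, N, d_head]

print(torch.allclose(loop_out, batched_out, atol=1e-6))  # True
```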

2.3 Concat + projection#

Concat the N head outputs:

```
concat = head_0 || head_1 || ... || head_{N-1}  # [seq, N × d_head] = [seq, d_model]
output = concat @ W_O                           # [seq, d_model]
```

W_O is a learned projection [d_model, d_model]; it mixes the heads.

2.4 Total parameters#

MHA params: 4 × d_model² = 4 × 4096² ≈ 67M per layer (at Llama-3-8B's dimensions, counting full MHA). 32 layers × 67M ≈ 2.1B parameters just for the attention QKVO weights.
In more detail:
  • W_Q, W_K, W_V: each [d_model, d_model], 16.8M params apiece
  • W_O: [d_model, d_model], 16.8M params
  • Total: ≈ 67M
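The counts above are quick arithmetic to check:

```python
d_model = 4096
per_matrix = d_model * d_model      # one of W_Q, W_K, W_V, W_O
per_layer = 4 * per_matrix          # the full QKVO set
total = 32 * per_layer              # 32 layers

print(f"{per_matrix / 1e6:.1f}M per matrix")  # 16.8M per matrix
print(f"{per_layer / 1e6:.1f}M per layer")    # 67.1M per layer
print(f"{total / 1e9:.2f}B total")            # 2.15B total
```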

2.5 Niye head split#

Mathematical insight: partition the d_model dimensions into N subspaces; each subspace learns a different attention pattern.
Intuition: for the same parameter budget, N small heads have roughly the same total capacity as one big head, but with a diversity advantage.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Single big matrices (more efficient than 4 separate)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        batch, seq, _ = x.shape
        # Single matmul for Q, K, V
        qkv = self.qkv_proj(x)  # [batch, seq, 3*d_model]
        qkv = qkv.view(batch, seq, 3, self.n_heads, self.d_head)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, d_head]
        Q, K, V = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention (per head, batched)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)  # [batch, heads, seq, seq]
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = weights @ V  # [batch, heads, seq, d_head]
        # Concat heads
        out = out.transpose(1, 2).contiguous()  # [batch, seq, heads, d_head]
        out = out.view(batch, seq, self.d_model)
        # Output projection
        out = self.out_proj(out)
        return out


# Test
d_model = 4096
n_heads = 32
mha = MultiHeadAttention(d_model, n_heads)
print(f"Total params: {sum(p.numel() for p in mha.parameters()):,}")  # ~67M

x = torch.randn(1, 2048, d_model)
out = mha(x)
print(f"Output shape: {out.shape}")  # [1, 2048, 4096]
```
Multi-Head Attention: production-grade PyTorch

4-5. Head Pruning + Specialization#

4.1 Michel 2019: 'Are Sixteen Heads Really Better than One?'#

The Michel, Levy & Neubig 2019 paper. Empirical finding: 60%+ of BERT's heads can be pruned with minimal accuracy loss.
Implications:
  • Most heads are redundant or contribute only marginally
  • Some heads are critical; preserving those is enough
  • In production, head pruning is a practical route to efficient inference

4.2 Voita 2019: 'Analyzing Multi-Head Self-Attention'#

Voita et al. 2019 infer each head's specialization:
Head types:
  • Positional heads: attend by relative position (previous token, two tokens back, etc.)
  • Syntactic heads: dependency-tree patterns (subject-verb, noun-adjective)
  • Rare-token heads: specialized on rare vocabulary
  • Common heads: attend to frequent words

4.3 Empirical Turkish example#

A BERT-base-Turkish probe (32K vocab):
  • Layer 2, head 4: 'previous token' pattern (positional)
  • Layer 5, head 8: subject-verb pattern
  • Layer 9, head 11: pronoun-antecedent pattern
  • Layer 10, head 2: noun-modifier (Turkish adjective-noun order)
Most heads: 'noise', no real pattern, just training noise.

4.4 Voita findings: ~80% of heads pruneable#

  • 20% critical heads
  • 60% partially useful
  • 20% noise (near-zero contribution)
Production implication: distillation + head pruning yield 2-3x faster inference.
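As a toy sketch of the pruning idea: score heads by some importance measure and drop the lowest. The score below is a crude output-norm proxy with synthetic data (Michel et al. 2019 actually use a gradient-based sensitivity score):

```python
import torch

torch.manual_seed(0)
n_heads, seq, d_head = 8, 16, 32

# Per-head outputs from one attention layer (random stand-ins here)
head_out = torch.randn(n_heads, seq, d_head)
# Make two heads near-silent to simulate redundancy
head_out[3] *= 0.01
head_out[6] *= 0.01

# Crude importance proxy: mean output norm per head
importance = head_out.norm(dim=-1).mean(dim=-1)  # [n_heads]

k = 6  # keep the top-k heads
keep = importance.topk(k).indices.sort().values
print(keep.tolist())  # the near-silent heads 3 and 6 are dropped
```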

7-9. MQA + GQA — Modern Efficiency#

7.1 MQA (Multi-Query Attention, Shazeer 2019)#

Problem: the KV cache is large (32 heads × 128 dims × seq tokens, per layer).
MQA's solution: N query heads, but a single shared K head and a single shared V head.

```
Q: [seq, N, d_head]  # N query heads
K: [seq, 1, d_head]  # 1 shared head
V: [seq, 1, d_head]  # 1 shared head
```

KV cache: 32x smaller. Speed: 2-5x faster inference.
Quality cost: 1-2% perplexity (an acceptable trade-off).
Models using it: Falcon, PaLM. (Mistral-7B, contrary to a common claim, uses GQA; see 7.6.)
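A minimal sketch of the KV sharing with toy shapes: because the shared K/V have a size-1 head dimension, PyTorch matmul broadcasting applies them to every query head without copying.

```python
import torch

torch.manual_seed(0)
n_q_heads, seq, d_head = 8, 10, 16

Q = torch.randn(n_q_heads, seq, d_head)  # N query heads
K = torch.randn(1, seq, d_head)          # single shared key head
V = torch.randn(1, seq, d_head)          # single shared value head

# The size-1 head dim broadcasts against all query heads
scores = Q @ K.transpose(-2, -1) / d_head**0.5  # [n_q_heads, seq, seq]
out = torch.softmax(scores, dim=-1) @ V          # [n_q_heads, seq, d_head]
print(out.shape)  # torch.Size([8, 10, 16])
```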

7.2 GQA (Grouped-Query Attention, Ainslie 2023)#

MQA can be too aggressive. GQA is the middle road: G groups of query heads share the same K, V.

```
# N = 32 heads, G = 8 groups → each group of 4 query heads shares 1 K/V head
Q: [seq, 32, d_head]  # 32 query heads
K: [seq, 8, d_head]   # 8 key heads (shared)
V: [seq, 8, d_head]   # 8 value heads
```

KV cache: 4x smaller (32/8 = 4). Quality: ~0.5% perplexity loss.
Llama-3-8B: 32 Q heads, 8 KV heads (G=8). Llama-3-70B: 64 Q heads, 8 KV heads (G=8).

7.3 Empirical comparison (Ainslie 2023)#

7B-scale comparison (approximate figures):
  • MHA (32 Q, 32 KV): baseline
  • MQA (32 Q, 1 KV): -1.5% accuracy, 5x faster
  • GQA (32 Q, 8 KV): -0.3% accuracy, 4x faster
GQA sweet spot: quality preservation + meaningful speedup.

7.4 KV cache memory savings#

Llama-3-8B at 128K context, per request:
  • MHA: 2 × 32 layers × 32 KV heads × 128 × 128K × 2 bytes ≈ 69 GB
  • GQA (G=8): 2 × 32 × 8 × 128 × 128K × 2 ≈ 17 GB
  • MQA: 2 × 32 × 1 × 128 × 128K × 2 ≈ 2.1 GB
MQA enables far more concurrent requests per GPU.
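The KV cache arithmetic in code form (fp16 elements, decimal GB; the factor 2 at the front counts the separate K and V tensors):

```python
def kv_cache_gb(layers, kv_heads, d_head, seq_len, bytes_per_el=2):
    """Size of the K and V caches: 2 tensors × layers × kv_heads × d_head × seq_len."""
    return 2 * layers * kv_heads * d_head * seq_len * bytes_per_el / 1e9

seq = 128 * 1024  # 128K context
for name, kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    print(f"{name}: {kv_cache_gb(32, kv, 128, seq):.1f} GB")
# MHA: 68.7 GB
# GQA: 17.2 GB
# MQA: 2.1 GB
```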

7.5 PyTorch implementation outline#

```python
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_q_heads, n_kv_heads):
        super().__init__()
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_q_heads // n_kv_heads  # query heads per KV head
        self.d_head = d_model // n_q_heads
        self.q_proj = nn.Linear(d_model, n_q_heads * self.d_head)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_head)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_head)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq, _ = x.shape
        # Q split into n_q heads, K/V split into n_kv heads
        Q = self.q_proj(x).view(batch, seq, self.n_q_heads, self.d_head)
        K = self.k_proj(x).view(batch, seq, self.n_kv_heads, self.d_head)
        V = self.v_proj(x).view(batch, seq, self.n_kv_heads, self.d_head)
        # Repeat K, V so every query head has a matching KV head
        K = K.repeat_interleave(self.n_rep, dim=2)
        V = V.repeat_interleave(self.n_rep, dim=2)
        # Standard scaled dot-product attention
        # ...
```

7.6 Production model preferences (2026)#

  • Llama-3, Llama-3.1: GQA (G=8)
  • Mistral, Mixtral: GQA (G=8)
  • GPT-4o (estimated): GQA or MHA
  • Claude 3 (estimated): GQA
  • DeepSeek-V3: MLA (Multi-head Latent Attention, a newer variant)
GQA is the de facto modern standard.
✅ Lesson 8.2 Summary: Multi-Head + GQA/MQA
Multi-head attention: run the same input through N different projections and compute N different attentions. Each head learns a different pattern (positional, syntactic, semantic). The head-pruning literature: 60-80% of heads are redundant. GQA (Grouped-Query Attention, Ainslie 2023): N query heads share G key/value heads; KV cache memory drops dramatically (4x) with minimal quality loss (~0.5%). Llama-3, Mistral, GPT-4o: the modern standard. MQA is more aggressive (1 KV head), costing 1-2% quality. DeepSeek's MLA keeps pushing the attention-efficiency frontier. Lesson 8.3 moves on to attention pattern analysis and interpretability.

Next Lesson: Attention Patterns + Interpretability#

Lesson 8.3: induction heads (Anthropic findings), attention pattern types, BERT-vs-GPT attention differences, attention visualization tooling, attention analysis in Turkish.

Frequently Asked Questions

How should the number of heads be chosen? Empirically: N = d_model / 64 or d_model / 128. Llama-3-8B: 4096/128 = 32 heads. GPT-3: 12288/128 = 96. Trade-off: more heads means more diversity, but each head becomes shallower (smaller d_head). Sweet spot: d_head = 64-128.
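The rule of thumb in code form (d_head fixed at 128):

```python
def n_heads(d_model, d_head=128):
    assert d_model % d_head == 0
    return d_model // d_head

print(n_heads(4096))   # 32  (Llama-3-8B)
print(n_heads(12288))  # 96  (GPT-3)
```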
