Multi-Head Attention: N Parallel Heads, Concat + Projection, Grouped-Query Attention (GQA), Multi-Query Attention (MQA)
Why we split single attention into N parallel heads: each head's capacity to learn different patterns (syntactic, semantic, positional). Concat + output projection architecture, head-pruning empirical findings, grouped-query attention (GQA) as used in Llama-3 and Mistral, multi-query attention (MQA) as used in Falcon and PaLM, head visualization with Turkish examples.
Şükrü Yusuf KAYA
70 min read
Advanced 🧠 Not one attention but many in parallel: each head learns a different pattern
Single-head attention works, but it looks at the sequence from only one angle. In a phrase like 'İstanbul'un başkenti', a single head must commit to one choice of which token 'başkenti' attends to, and that is limiting. The multi-head solution: project the same input N different ways and compute N attentions in parallel. Each head can learn a different pattern: one tracks syntactic dependencies (subject-verb), another semantic similarity, another positional structure. Llama-3 has 32 heads; GPT-4 (estimated) has 96. Seventy minutes from now you will have a deep grasp of the multi-head math, what heads empirically learn, and the modern GQA/MQA optimizations used in the Llama-3 and Mistral architectures.
Lesson Map (10 Sections)#
- Why multi-head: the limit of a single head
- The math: split, parallel attention, concat
- Implementation: W_Q, W_K, W_V as one big matrix
- Head pruning: empirical findings (Michel 2019, Voita 2019)
- Head specialization: syntactic, semantic, positional roles
- Memory cost: the KV cache summed over N heads
- Multi-Query Attention (MQA): Shazeer 2019; Falcon, PaLM
- Grouped-Query Attention (GQA): Ainslie 2023; Llama-3
- GQA/MQA empirics: the quality vs. efficiency trade-off
- Turkish head visualization: building practical intuition
2. Multi-Head Math#
2.1 Split#
Input x shape: [seq, d_model], where d_model = N × d_head (N heads, each of width d_head).
Llama-3-8B: d_model = 4096, N = 32 → d_head = 128.
```python
Q, K, V = x @ W_Q, x @ W_K, x @ W_V  # all [seq, d_model]
```
Reshape into heads:
```python
Q = Q.view(seq, N, d_head)  # [seq, N, d_head]
K = K.view(seq, N, d_head)
V = V.view(seq, N, d_head)
```
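A minimal runnable check of the split, with a random matrix standing in for the learned W_Q (dimensions follow Llama-3-8B):

```python
import torch

seq, d_model, N = 16, 4096, 32
d_head = d_model // N  # 128

x = torch.randn(seq, d_model)
W_Q = torch.randn(d_model, d_model)  # random stand-in for the learned weight

Q = (x @ W_Q).view(seq, N, d_head)  # [16, 32, 128]
print(Q.shape)
```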
2.2 Parallel attention per head#
Each head computes attention independently:
```python
for h in range(N):
    head_h = attention(Q[:, h, :], K[:, h, :], V[:, h, :])  # [seq, d_head]
```
In practice a single batched einsum/matmul replaces this loop, as sketched below.
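One way the batched form can look with torch.einsum; the shapes follow the [seq, N, d_head] layout above (a sketch, not the exact kernel any framework ships):

```python
import math
import torch

seq, N, d_head = 16, 8, 64
Q = torch.randn(seq, N, d_head)
K = torch.randn(seq, N, d_head)
V = torch.randn(seq, N, d_head)

# scores[h, q, k] = Q[q, h, :] · K[k, h, :] / sqrt(d_head), all heads at once
scores = torch.einsum("qhd,khd->hqk", Q, K) / math.sqrt(d_head)
weights = torch.softmax(scores, dim=-1)           # softmax over the key axis
heads = torch.einsum("hqk,khd->qhd", weights, V)  # [seq, N, d_head]
print(heads.shape)  # torch.Size([16, 8, 64])
```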
2.3 Concat + projection#
Concatenate the N head outputs:
```python
concat = head_0 || head_1 || ... || head_{N-1}  # [seq, N × d_head] = [seq, d_model]
output = concat @ W_O                           # [seq, d_model]
```
W_O is a learned projection of shape [d_model, d_model]; it mixes information across the heads.
2.4 Total parameters#
MHA params: 4 × d_model² = 4 × 4096² ≈ 67M per layer at Llama-3-8B dimensions (Llama-3 itself uses GQA, so its real K/V projections are smaller; see §7).
32 layers × 67M ≈ 2.1B params just for the attention QKVO weights.
In more detail:
- W_Q, W_K, W_V: each [d_model, d_model], ≈16.8M params
- W_O: [d_model, d_model], ≈16.8M params
- Total: ≈67M (checked in the quick script below)
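The same count as a quick script (full-MHA layout, matching the breakdown above):

```python
d_model, n_layers = 4096, 32

per_matrix = d_model * d_model  # 16,777,216 ≈ 16.8M
per_layer = 4 * per_matrix      # W_Q + W_K + W_V + W_O ≈ 67M
total = n_layers * per_layer    # ≈ 2.15B
print(f"{per_matrix / 1e6:.1f}M per matrix, {per_layer / 1e6:.1f}M per layer, {total / 1e9:.2f}B total")
```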
2.5 Why split into heads#
Mathematical insight: the split partitions the d_model dimensions into N subspaces, and each subspace can learn its own attention pattern.
Intuition: for the same parameter budget, N small heads have roughly the total capacity of one large head, but with a diversity advantage.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Single big matrices (more efficient than 4 separate)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        batch, seq, _ = x.shape
        # Single matmul for Q, K, V
        qkv = self.qkv_proj(x)  # [batch, seq, 3*d_model]
        qkv = qkv.view(batch, seq, 3, self.n_heads, self.d_head)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, d_head]
        Q, K, V = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention (per head, batched)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)  # [batch, heads, seq, seq]
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = weights @ V  # [batch, heads, seq, d_head]
        # Concat heads
        out = out.transpose(1, 2).contiguous()  # [batch, seq, heads, d_head]
        out = out.view(batch, seq, self.d_model)
        # Output projection
        out = self.out_proj(out)
        return out

# Test
d_model = 4096
n_heads = 32
mha = MultiHeadAttention(d_model, n_heads)
print(f"Total params: {sum(p.numel() for p in mha.parameters()):,}")  # ~67M

x = torch.randn(1, 2048, d_model)
out = mha(x)
print(f"Output shape: {out.shape}")  # [1, 2048, 4096]
```
Multi-Head Attention: production-grade PyTorch
4-5. Head Pruning + Specialization#
4.1 Michel 2019: 'Are Sixteen Heads Really Better than One?'#
The Michel, Levy, and Neubig (2019) paper. Empirical finding: over 60% of BERT's heads can be pruned with minimal accuracy loss.
Implications:
- Most heads are redundant or contribute only marginally
- A few heads are critical; protecting those is enough
- Production: head pruning is practical for efficient inference (see the sketch below)
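A minimal sketch of the ablation mechanic behind such studies. The hook point (head outputs just before concat + W_O, shaped [batch, seq, n_heads, d_head]) matches the MultiHeadAttention class above; the helper itself is illustrative, not code from the paper:

```python
import torch

def ablate_heads(head_outputs: torch.Tensor, dead_heads: list) -> torch.Tensor:
    # head_outputs: [batch, seq, n_heads, d_head], taken right before concat + W_O
    mask = torch.ones(head_outputs.shape[2], device=head_outputs.device)
    mask[dead_heads] = 0.0
    return head_outputs * mask.view(1, 1, -1, 1)

# Pruning experiment in outline: ablate each head in turn, re-evaluate
# the dev metric, and rank heads by how much damage removal causes.
out = ablate_heads(torch.randn(1, 8, 32, 128), dead_heads=[3, 17])
print(out[0, 0, 3].abs().sum(), out[0, 0, 4].abs().sum())  # head 3 zeroed, head 4 intact
```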
4.2 Voita 2019: 'Analyzing Multi-Head Self-Attention'#
Voita et al. (2019) infer a specialization for each head:
Head types:
- Positional heads: attend to relative positions (previous token, two tokens back, etc.)
- Syntactic heads: dependency-tree patterns (subject-verb, noun-adjective)
- Rare-token heads: specialized for rare vocabulary
- Common heads: attend to frequent words
4.3 Empirical Turkish example#
A probe of BERT-base-Turkish (32K vocab):
- Layer 2, head 4: 'previous token' pattern (positional)
- Layer 5, head 8: 'subject-verb' pattern
- Layer 9, head 11: 'pronoun-antecedent' pattern
- Layer 10, head 2: 'noun-modifier' pattern (Turkish adjective-noun order)
Most heads: 'noise', no genuine pattern, just training noise. A sketch of how such a probe can be run follows.
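A possible probe setup with Hugging Face transformers. The checkpoint name is an assumption (BERTurk, which has a 32K vocab); any Turkish BERT exposes attention weights the same way:

```python
import torch
from transformers import AutoModel, AutoTokenizer

name = "dbmdz/bert-base-turkish-cased"  # assumed checkpoint
tok = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name, output_attentions=True)

inputs = tok("Ankara Türkiye'nin başkentidir.", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)

# out.attentions: one tensor per layer, each [batch, n_heads, seq, seq]
attn = out.attentions[1][0, 3]  # layer 2, head 4 (0-indexed)
tokens = tok.convert_ids_to_tokens(inputs["input_ids"][0])
for q, row in zip(tokens, attn):
    print(f"{q:>12} -> {tokens[row.argmax().item()]}")  # strongest attention target per token
```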
4.4 Voita findings: 80% of heads are pruneable#
- 20% critical heads
- 60% partially useful
- 20% noise (zero contribution)
Production implication: distillation plus head pruning can yield 2-3x faster inference.
7-9. MQA + GQA — Modern Efficiency#
7.1 MQA (Multi-Query Attention, Shazeer 2019)#
Problem: the KV cache is large; with MHA every layer stores 32 heads × d_head 128 × seq length for both K and V.
The MQA solution: keep N query heads, but share a single K head and a single V head across all of them.
```
Q: [seq, N, d_head]   # N query heads
K: [seq, 1, d_head]   # 1 shared key head
V: [seq, 1, d_head]   # 1 shared value head
```
KV cache: 32x smaller. Speed: 2-5x faster inference.
Quality cost: 1-2% worse perplexity (an acceptable trade-off).
Models using it: Falcon, PaLM. (Mistral-7B uses GQA, not MQA; see §7.2.)
7.2 GQA (Grouped-Query Attention, Ainslie 2023)#
MQA can be too aggressive. GQA is the middle ground: the query heads are split into G groups, and each group shares one K/V head.
```
N = 32 heads, G = 8 groups → each group of 4 query heads shares 1 K/V head
Q: [seq, 32, d_head]   # 32 query heads
K: [seq, 8, d_head]    # 8 key heads (shared)
V: [seq, 8, d_head]    # 8 value heads
```
KV cache: 4x smaller (32/8 = 4).
Quality: roughly 0.5% perplexity loss.
Llama-3-8B: 32 Q heads, 8 KV heads (G=8). Llama-3-70B: 64 Q, 8 KV (G=8).
7.3 Empirical comparison (Ainslie 2023)#
Llama-2 7B baseline:
- MHA (32 Q, 32 KV): baseline
- MQA (32 Q, 1 KV): -1.5% accuracy, 5x faster
- GQA (32 Q, 8 KV): -0.3% accuracy, 4x faster
GQA sweet spot: quality preservation + meaningful speedup.
7.4 KV cache memory savings#
Llama-3-8B at 128K context, per request:
- MHA (hypothetical 32 KV heads): 2 × 32 layers × 32 heads × 128 × 128K × 2 bytes ≈ 68.7 GB
- GQA (G=8): 2 × 32 × 8 × 128 × 128K × 2 ≈ 17.2 GB
- MQA: 2 × 32 × 1 × 128 × 128K × 2 ≈ 2.1 GB
MQA lets far more concurrent requests fit on one GPU. The arithmetic is wrapped into a small helper below.
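The same arithmetic as a helper (per request: K and V, every layer, every KV head):

```python
def kv_cache_gb(n_layers, n_kv_heads, d_head, seq_len, bytes_per_elem=2):
    # factor 2 = one K tensor + one V tensor per layer
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem / 1e9

ctx = 128 * 1024
print(kv_cache_gb(32, 32, 128, ctx))  # MHA:       ~68.7 GB
print(kv_cache_gb(32, 8, 128, ctx))   # GQA (G=8): ~17.2 GB
print(kv_cache_gb(32, 1, 128, ctx))   # MQA:       ~2.1 GB
```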
7.5 PyTorch implementation outline#
```python
class GroupedQueryAttention(nn.Module):  # uses torch, nn, F, math imported above
    def __init__(self, d_model, n_q_heads, n_kv_heads):
        super().__init__()
        assert n_q_heads % n_kv_heads == 0
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_q_heads // n_kv_heads  # query heads per shared KV head
        self.d_head = d_model // n_q_heads
        self.q_proj = nn.Linear(d_model, n_q_heads * self.d_head)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_head)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_head)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq, _ = x.shape
        # Q split into n_q heads, K/V split into (fewer) n_kv heads
        Q = self.q_proj(x).view(batch, seq, self.n_q_heads, self.d_head)
        K = self.k_proj(x).view(batch, seq, self.n_kv_heads, self.d_head)
        V = self.v_proj(x).view(batch, seq, self.n_kv_heads, self.d_head)
        # Repeat K, V so each query head finds its shared KV head
        # (in a real server only the small K/V go into the KV cache)
        K = K.repeat_interleave(self.n_rep, dim=2)
        V = V.repeat_interleave(self.n_rep, dim=2)
        # Standard scaled dot-product attention, batched over heads
        Q, K, V = (t.transpose(1, 2) for t in (Q, K, V))  # [batch, heads, seq, d_head]
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)
        weights = F.softmax(scores, dim=-1)
        out = (weights @ V).transpose(1, 2).reshape(batch, seq, -1)
        return self.out_proj(out)
```
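Setting n_kv_heads=1 turns the same class into MQA. A quick check at Llama-3-8B dimensions (rough totals, biases ignored):

```python
gqa = GroupedQueryAttention(d_model=4096, n_q_heads=32, n_kv_heads=8)  # Llama-3-8B layout
mqa = GroupedQueryAttention(d_model=4096, n_q_heads=32, n_kv_heads=1)  # the MQA extreme
print(sum(p.numel() for p in gqa.parameters()))  # ≈ 42M, vs ≈ 67M for full MHA
print(sum(p.numel() for p in mqa.parameters()))  # ≈ 35M
```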
7.6 Production model preferences (2026)#
- Llama-3, Llama-3.1: GQA (G=8)
- Mistral, Mixtral: GQA (G=8)
- GPT-4o (speculated): GQA or MHA
- Claude 3 (speculated): GQA
- DeepSeek-V3: MLA (Multi-head Latent Attention, a newer variant)
GQA is the de facto modern standard.
✅ Lesson 8.2 Summary: Multi-Head + GQA/MQA
Multi-head attention: the same input is projected N different ways and N attentions are computed in parallel. Each head learns a different pattern (positional, syntactic, semantic). The head-pruning literature finds 60-80% of heads redundant. GQA (Grouped-Query Attention, Ainslie 2023): N query heads share G key/value heads; KV cache memory drops dramatically (4x) with minimal quality loss (~0.5%). Used by Llama-3 and Mistral, it is the modern standard. MQA is more aggressive (1 KV head) at a 1-2% quality cost. DeepSeek's MLA keeps pushing the attention-efficiency frontier. Lesson 8.3 moves on to attention pattern analysis and interpretability.
Next Lesson: Attention Patterns + Interpretability#
Lesson 8.3: induction heads (Anthropic findings), attention pattern types, BERT vs. GPT attention differences, attention visualization tooling, Turkish attention analysis.
Frequently Asked Questions
How should the number of heads N be chosen?
Empirically, N = d_model / 64 or d_model / 128. Llama-3-8B: 4096/128 = 32 heads. GPT-3: 12288/128 = 96. Trade-off: more heads mean more diversity, but each head becomes 'shallower' (small d_head). Sweet spot: d_head = 64-128.