Skip to content

MoE Mathematical Anatomy: Gating Network, Top-k Routing, Load Balancing — Sparse Activation from Scratch

Internal mathematics of MoE: derivation of gating network, top-k routing implementation, expert collapse problem and load balancing loss (Shazeer 2017), auxiliary loss math, capacity factor, drop tokens, FLOP analysis. PyTorch MoE FFN layer implementation from scratch. Expert utilization observations on Turkish data.

Şükrü Yusuf KAYA
85 min read
Advanced
MoE Matematik Anatomi: Gating Network, Top-k Routing, Load Balancing — Sparse Activation Sıfırdan
🔬 MoE — Üç Matematik Probleminin Zarif Çözümü
MoE basit görünüyor: 'Birden fazla expert var, gating onları seçer.' Ama altında 3 zor matematik problemi yatıyor:
(1) Gating nasıl differentiable olur? Sparse top-k seçim discrete. Backprop nasıl?
(2) Load balancing nasıl? Eğer bütün token'lar aynı expert'i seçerse, sistem bozuluyor. 256 expert'ten 1 tanesi tüm yükü taşıyor.
(3) Memory + compute trade-off nasıl yönetilir? Capacity factor, drop tokens, expert paralleli.
Shazeer 2017'den DeepSeek-V3 2024'e kadar her major MoE paper'ı bu 3 problem üzerinde inceleme yapıyor. Bu ders matematiği sıfırdan kuruyoruz, PyTorch'ta implementation yapıyoruz, Türkçe data üzerinde expert utilization gözlemliyoruz. 85 dakika sonra: MoE'yi sadece kullanmayacak, debug edebilecek + optimize edebilecek mühendis olacaksın.

Bu Derste Neler Var? (13 Bölüm)#

  1. MoE FFN layer matematiği — formal tanım
  2. Gating network — softmax + top-k
  3. Top-k routing differentiability — Gumbel trick
  4. Expert collapse problemi
  5. Load balancing loss (Shazeer 2017)
  6. Auxiliary-loss-free load balancing (DeepSeek 2024)
  7. Capacity factor — bütçe sınırı
  8. Drop tokens stratejisi
  9. FLOP analizi — gerçek hesap maliyeti
  10. Memory analizi — niye MoE memory-hungry
  11. PyTorch sıfırdan MoE FFN layer
  12. Türkçe expert utilization deneyi
  13. Egzersizler

1-4. MoE FFN Layer Matematiği#

1.1 Klasik FFN (Modül 10)#

input x → FFN: y = W_2 · ReLU(W_1 · x + b_1) + b_2
Tek FFN, tüm parametreler her token için aktif.

1.2 MoE FFN (Shazeer 2017)#

input x → Gating: g(x) = softmax(W_g · x) ∈ ℝ^N Gating top-k seçim: TopK(g(x), k) ⊂ {1, ..., N} For each i in TopK: expert_i_output = FFN_i(x) # her expert kendi W_1_i, W_2_i Final output: y = Σ_{i ∈ TopK} g(x)_i × expert_i_output
N expert (genelde 8-256), her biri ayrı FFN. Top-k (genelde 1 veya 2 veya 8) aktif.

2.1 Gating network detayı#

Gating network basit bir linear layer + softmax:
import torch import torch.nn as nn import torch.nn.functional as F class Gating(nn.Module): def __init__(self, d_model, n_experts): super().__init__() self.W_g = nn.Linear(d_model, n_experts, bias=False) def forward(self, x): # x: [batch, seq, d_model] logits = self.W_g(x) # [batch, seq, n_experts] probs = F.softmax(logits, dim=-1) return probs

3.1 Top-k routing — differentiable mi?#

Problem: top_k discrete operation, ∂(top_k) tanımlı değil.
Çözüm: 'straight-through estimator'. Forward'da hard top-k, backward'da soft softmax gradient.
Daha temiz: Gumbel-Softmax trick (Maddison 2017, Jang 2017). Discrete decision'ları noisy softmax ile yaklaş, gradient flow.
Modern pratik: Top-2 routing with auxiliary loss. Gumbel kullanmıyorlar, sadece top-2 + load balancing loss.

4.1 Expert Collapse Problemi#

Naive MoE training'in en yaygın hatası: gating, tüm tokenları aynı expert'e yönlendirmeye başlar.
Neden? Random init'te bir expert biraz daha iyi gradient alır → daha çok seçilir → daha çok gradient → exponentially büyür. Diğer expert'lar 'açlıktan ölür' (no gradient).
Empirik: 64 expert MoE, naif training → 60 expert hiç kullanılmıyor, 4 expert tüm yükü taşıyor. 'Mode collapse'.
Çözüm gerekli.

5-8. Load Balancing + Capacity Factor#

5.1 Shazeer 2017 — Load Balancing Loss#

Auxiliary loss ekle: gating'i 'tüm expert'ları eşit kullan'mak için zorla.
Matematik:
L_aux = N · Σ_i f_i · p_i
Nerede:
  • f_i = expert_i'nin gerçek seçilme oranı (batch'te)
  • p_i = gating'in beklenti olasılığı (batch ortalaması)
İdeal durumda: f_i = p_i = 1/N (uniform).
Loss minimize: f ve p uniformly dağılmaya zorlanıyor.
Toplam loss:
L_total = L_task + α × L_aux
α = 0.01 typical.

6.1 DeepSeek-V3 — Auxiliary-loss-free#

2024 yeniliği: 'Auxiliary loss aslında task performance'ı bozuyor (modeli homogenize etmeye zorluyor)'.
Çözüm: bias terms ekle gating'e:
g(x) = softmax(W_g · x + b)
b_i sürekli güncellenir: under-used expert'in b'si artırılır, over-used'unki azaltılır. Auxiliary loss yok.
Result: %5-10 better task performance, load balance hâlâ iyi.

7.1 Capacity Factor#

Eşit dağılım garantisi olmadığında bile, bazı expert'lar dolup taşabilir. Memory koruması için capacity factor:
capacity_per_expert = (total_tokens / n_experts) × capacity_factor
Örnek: 100 token, 8 expert, capacity_factor=1.25. Her expert max 100/8 × 1.25 = ~16 token alabilir.
Kapasite dolarsa: extra token'lar drop edilir (next layer'a olduğu gibi geçer, expert bypass).

8.1 Drop tokens trade-off#

Capacity_factor düşükse (1.0): %5-10 token drop → bilgi kaybı Capacity_factor yüksekse (2.0): tüm token'lar geçer ama memory 2× → pahalı
Modern norm: 1.25-1.5 sweet spot.
DeepSeek-V3: capacity_factor=1.5, drop oranı %0.3 (negligible).
python
# MoE FFN Layer — sıfırdan PyTorch implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class MoEFFN(nn.Module):
def __init__(self, d_model, d_ff, n_experts=8, top_k=2, capacity_factor=1.25):
super().__init__()
self.n_experts = n_experts
self.top_k = top_k
self.capacity_factor = capacity_factor
self.d_model = d_model
# Gating
self.gating = nn.Linear(d_model, n_experts, bias=False)
# N expert (each = FFN)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, d_ff),
nn.SiLU(),
nn.Linear(d_ff, d_model),
)
for _ in range(n_experts)
])
def forward(self, x):
# x: [batch, seq, d_model]
B, L, D = x.shape
x_flat = x.view(B * L, D) # [B*L, D]
# 1. Gating
gate_logits = self.gating(x_flat) # [B*L, n_experts]
gate_probs = F.softmax(gate_logits, dim=-1)
# 2. Top-k selection
top_k_probs, top_k_indices = gate_probs.topk(self.top_k, dim=-1)
# top_k_probs: [B*L, top_k]
# top_k_indices: [B*L, top_k]
# Renormalize
top_k_probs = top_k_probs / (top_k_probs.sum(dim=-1, keepdim=True) + 1e-9)
# 3. Capacity
capacity = int(B * L * self.top_k / self.n_experts * self.capacity_factor)
# 4. Expert dispatch + processing
output = torch.zeros_like(x_flat)
for expert_id in range(self.n_experts):
# Hangi token'lar bu expert'e gidiyor?
mask = (top_k_indices == expert_id).any(dim=-1)
if mask.sum() == 0:
continue
# Token'lar
expert_tokens = x_flat[mask] # [n_assigned, D]
# Capacity check
if expert_tokens.size(0) > capacity:
expert_tokens = expert_tokens[:capacity] # drop overflow
# Expert forward
expert_output = self.experts[expert_id](expert_tokens)
# Gating weight
weight_mask = (top_k_indices == expert_id)
weights = (top_k_probs * weight_mask.float()).sum(dim=-1)
weights = weights[mask][:expert_output.size(0)]
# Add to output
output[mask][:expert_output.size(0)] += weights.unsqueeze(-1) * expert_output
return output.view(B, L, D)
def load_balancing_loss(self, x):
"""Shazeer 2017 auxiliary loss."""
B, L, D = x.shape
x_flat = x.view(B * L, D)
gate_logits = self.gating(x_flat)
gate_probs = F.softmax(gate_logits, dim=-1)
# f_i: empirical fraction of tokens routed to expert i
_, top_k_indices = gate_probs.topk(self.top_k, dim=-1)
one_hot = F.one_hot(top_k_indices.view(-1), self.n_experts).float()
f = one_hot.mean(dim=0)
# p_i: mean probability of expert i
p = gate_probs.mean(dim=0)
# L_aux = N · Σ f_i · p_i
loss = self.n_experts * (f * p).sum()
return loss
 
# Kullanım
moe = MoEFFN(d_model=4096, d_ff=14336, n_experts=8, top_k=2)
# (Mixtral 8x7B config benzeri)
 
x = torch.randn(2, 512, 4096)
out = moe(x)
print(f'Output shape: {out.shape}') # [2, 512, 4096]
 
aux_loss = moe.load_balancing_loss(x)
print(f'Auxiliary loss: {aux_loss.item():.4f}')
 
MoE FFN — Sıfırdan PyTorch Implementation

9-10. FLOP ve Memory Analizi#

9.1 Dense vs MoE FLOP#

Dense FFN (Llama-3-8B per layer):
  • d_model=4096, d_ff=14336
  • FLOP per token: 2 × 4096 × 14336 = 117M
MoE FFN (Mixtral 8x7B):
  • d_model=4096, d_ff=14336, n_experts=8, top_k=2
  • FLOP per token: 2 × (2 × 4096 × 14336) = 234M (top-2 active)
  • Per-expert FLOP aynı, ama top-2 olduğu için 2× compute
Mixtral 8×7B = 47B total parameter, ama active per token ≈ 13B parametre (47B'nin %28'i).

9.2 DeepSeek-V3 FLOP#

  • 671B total, 37B active per token
  • 37B / 671B = %5.5 active ratio
  • Aktif compute: 37B dense model'inkiyle aynı
  • Quality: 671B dense ile karşılaştırılabilir (capacity)
Sweet spot: kompleksin %5-30 active ratio. DeepSeek-V3 (%5.5) Mixtral'den daha agresif sparse.

10.1 Memory analizi#

MoE'nin gerçek sınırı: memory.
Mixtral 8x7B (47B param, bf16): 94 GB.
  • 1× H100 80GB: yetmez!
  • 2× H100 NVLink: yeterli
  • 1× H100 + AWQ 4-bit: yeterli (~25 GB)
DeepSeek-V3 (671B param, bf16): 1.3 TB.
  • 8× H100 80GB: yeterli (FSDP veya expert parallel)
  • 4× H200 141GB: yeterli
  • 4-bit quantize: ~340 GB, 4× H100 mümkün

10.2 Expert parallelism (MoE-specific)#

MoE memory dağıtımı için yeni paralelleşme tipi: expert parallelism.
  • Expert_1, Expert_2, ..., Expert_N farklı GPU'larda
  • Token routing GPU'lar arası all-to-all communication (network heavy)
  • DDP veya FSDP ile birleştirilebilir
DeepSeek-V3 production: expert parallelism + pipeline parallelism + data parallelism (3D).

11.1 Türkçe Expert Utilization Deneyi#

Mixtral 8x7B'de Türkçe input için expert utilization gözle (HuggingFace hook):
from transformers import AutoModel, AutoTokenizer import torch model = AutoModel.from_pretrained('mistralai/Mixtral-8x7B-Instruct-v0.1', torch_dtype=torch.bfloat16, device_map='auto') tokenizer = AutoTokenizer.from_pretrained('mistralai/Mixtral-8x7B-Instruct-v0.1') # Hook MoE gating activations = {} def hook(name): def fn(module, inp, out): # out = (final_hidden_state, router_logits) router_logits = out[1] activations[name] = router_logits return fn for i, layer in enumerate(model.model.layers): layer.block_sparse_moe.register_forward_hook(hook(f'layer_{i}')) # Türkçe metni encode text = 'İstanbul Boğazı, Karadeniz ile Marmara Denizi\'ni birleştirir.' inputs = tokenizer(text, return_tensors='pt').to(model.device) with torch.no_grad(): _ = model(**inputs) # Expert utilization analiz for name, logits in activations.items(): # logits: [batch, seq, n_experts] top_2_experts = logits.topk(2, dim=-1).indices expert_counts = torch.zeros(8) for e in top_2_experts.flatten(): expert_counts[e] += 1 print(f'{name}: {expert_counts.tolist()}') # Gözlem: bazı layer'larda 1-2 expert'a heavy bias (Türkçe için Mixtral expert'leri "İngilizce" odaklı)
Bu gözlem Türkçe DPO fine-tune'un niye gerekli olduğunu kanıtlar.
✅ Ders 18.2 Özeti — MoE Matematik
MoE FFN layer = N expert + gating + top-k routing. Matematik 3 zor problem çözüyor: gating differentiability (straight-through veya Gumbel), load balancing (Shazeer 2017 aux loss veya DeepSeek 2024 bias trick), capacity (factor + drop tokens). PyTorch sıfırdan implementation mümkün ~80 satır. FLOP: MoE aktif params kadar (sparse). Memory: tüm params (dense). Mixtral 47B = 94GB bf16, DeepSeek-V3 671B = 1.3TB. Türkçe deneyi: Mixtral'da expert utilization Türkçe için biased (İngilizce-odaklı pre-training etkisi). Sonraki ders: DeepSeek-V3 inovasyonları — auxiliary-loss-free, MLA, multi-token prediction.

Sonraki Ders: DeepSeek-V3 İnovasyonları#

Ders 18.3'te DeepSeek-V3'ün 3 kritik yeniliği. Multi-head Latent Attention (MLA) — attention'a sparse benzeri optimizasyon. Auxiliary-loss-free load balancing — gating'i daha temiz çözen yaklaşım. Multi-token prediction — bir adım yerine 2-3 token tahmin etmek. Bu inovasyonların matematiği, Türkçe için ne anlama geliyor.

Frequently Asked Questions

Trade-off: **Top-k=1**: - Compute: halved (only 1 expert) - Memory: same - Quality: slightly lower than top-k=2 (-%2-3 benchmark) **Top-k=2** (modern norm): - Compute: 2× (2 experts active) - Memory: same - Quality: best **Top-k=8** (DeepSeek-V3): - Compute: 8× (but small experts) - Memory: same - Quality: maximum diversity, best **Decision**: DeepSeek-V3's 256 small experts + top-8 approach is 2024 trend. 'Fine-grained expert' philosophy halfway. Medium scale (10-30B param): top-2. Large scale (100B+): top-8 + many small experts.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

Related Content