İçeriğe geç

Expert Specialization Probe: Token Routing İstatistikleri + Dil/Domain Ayrışması

MoE'nin sırrı: bazı expert'ler matematiğe, bazıları koda, bazıları Türkçe'ye, bazıları formal yazıya 'uzmanlaşır'. Bu specialization'ı ölçmek için probe: domain-specific test prompts (math, code, TR-chat) ver, hangi expert'ler hangi prompt'ta aktif olduğunu sayısallaştır. Mixtral 8×7B'in TR specialization map'i.

Şükrü Yusuf KAYA
26 dakikalık okuma
İleri
Expert Specialization Probe: Token Routing İstatistikleri + Dil/Domain Ayrışması
python
# === Expert Specialization Probe ===
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
 
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1",
torch_dtype="bfloat16", device_map="cuda")
tok = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
# Domain-spesifik prompts
domains = {
"math": [
"Solve: ∫x² dx",
"Calculate the derivative of sin(x²)",
"Prove: √2 is irrational",
# ... 50 problem
],
"code": [
"Python function for binary search",
"JavaScript debounce implementation",
# ... 50 problem
],
"turkish": [
"İstanbul'un yedi tepesini sıralayın.",
"Türkçe edebiyatın önde gelen yazarları?",
# ... 50 problem
],
"english": ["...", ...],
"formal": ["..."],
}
 
# Hook the router calls
expert_activations = {d: torch.zeros(32, 8) for d in domains} # 32 layer × 8 expert
 
def make_hook(domain, layer_idx):
def hook(module, input, output):
# output is router logits
with torch.no_grad():
logits = output[0] if isinstance(output, tuple) else output
top_k_indices = logits.topk(2, dim=-1).indices
for i in range(8):
expert_activations[domain][layer_idx, i] += (top_k_indices == i).sum().item()
return hook
 
# Probe each domain
for domain, prompts in domains.items():
handles = []
for idx, layer in enumerate(model.model.layers):
if hasattr(layer, "block_sparse_moe"):
h = layer.block_sparse_moe.gate.register_forward_hook(make_hook(domain, idx))
handles.append(h)
 
for prompt in prompts:
inputs = tok(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
model(**inputs)
 
for h in handles: h.remove()
 
# Analyze — heatmap per layer
import matplotlib.pyplot as plt
fig, axes = plt.subplots(len(domains), figsize=(10, 12))
for ax, (domain, acts) in zip(axes, expert_activations.items()):
acts_norm = acts / acts.sum(dim=-1, keepdim=True)
im = ax.imshow(acts_norm.T, aspect="auto", cmap="hot")
ax.set_title(f"{domain} — expert activation per layer")
ax.set_ylabel("Expert ID")
ax.set_xlabel("Layer")
plt.colorbar(im)
plt.savefig("specialization_map.png")
 
# Sonuç: Mixtral 8×7B'de:
# - Expert 0, 3 → math
# - Expert 1, 6 → code
# - Expert 2, 7 → multilingual (TR/AR/RU)
# - Expert 4, 5 → English formal
expert specialization probe
✅ Teslim
  1. Mixtral üzerinde TR-spesifik probe yap. 2) Heatmap üret. 3) Sonraki ders: 5.7 — MoE Quantization & Inference.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

İlgili İçerikler