İçeriğe geç

Mixtral 8×7B / 8×22B FT: Router Collapse Problemi + Aux Loss Weight Kalibrasyon

Mixtral'in FT'sinde en sık karşılaşılan bug: **router collapse** — eğitim ilerledikçe bir expert dominat olur, diğerleri dead. Capacity overflow, aux loss weight'in dinamik adaptasyonu, expert balance metrics ölçümü, FSDP + MoE uyumu (expert parallelism). 4×H100 80GB Mixtral 8×7B QLoRA reçetesi (~4 saat).

Şükrü Yusuf KAYA
32 dakikalık okuma
İleri
Mixtral 8×7B / 8×22B FT: Router Collapse Problemi + Aux Loss Weight Kalibrasyon
python
# === Expert Balance Metrics ölçümü ===
# Mixtral router'ı her step için expert distribution'unu logla
import torch
 
class RouterMonitor:
def __init__(self, num_experts=8):
self.num_experts = num_experts
self.token_counts = torch.zeros(num_experts)
self.weight_sums = torch.zeros(num_experts)
 
def update(self, router_logits, expert_indices):
"""Her FFN forward call'unda çağır."""
with torch.no_grad():
for i in range(self.num_experts):
mask = (expert_indices == i)
self.token_counts[i] += mask.sum().item()
self.weight_sums[i] += router_logits[mask].sum().item()
 
def report(self):
total = self.token_counts.sum()
if total == 0: return {}
f = self.token_counts / total
return {
"expert_load_ratio": f.tolist(),
"expert_load_std": f.std().item(), # 0 = perfect balance
"expert_load_max": f.max().item(), # >0.4 → router collapse riski
"expert_load_min": f.min().item(), # <0.05 → dead expert
}
 
# Training callback ile entegre et
from transformers import TrainerCallback
 
class MoEMonitorCallback(TrainerCallback):
def __init__(self, model):
self.monitors = {}
for layer_idx, layer in enumerate(model.model.layers):
if hasattr(layer, "block_sparse_moe"):
self.monitors[layer_idx] = RouterMonitor()
 
def on_step_end(self, args, state, control, **kwargs):
if state.global_step % 50 == 0:
for idx, mon in self.monitors.items():
report = mon.report()
print(f"layer-{idx}: load_std={report['expert_load_std']:.3f}, max={report['expert_load_max']:.3f}")
 
# Sağlıklı bir Mixtral FT'inde her layer'da:
# - load_std ~0.03-0.08 (uniform'a yakın)
# - load_max ~0.15-0.20 (ideal: 1/8 = 0.125)
# - load_min ~0.08-0.12
 
# Router collapse'da:
# - load_std > 0.15
# - load_max > 0.40 (bir expert tüm token'ları yutuyor)
# - load_min < 0.02 (dead expert)
expert balance monitoring

1. Router Collapse Mitigation#

SebepÇözüm
Aux loss weight çok düşükα'yı 0.001 → 0.01'e çık
Learning rate çok yüksekLR'i 1/2-1/3'e indir
Dataset bias (tek-dil dominat)Veri mix balance
FT epoch sayısı yüksekEpoch azalt (overfitting risk MoE'de yüksek)
Router frozen değilLoRA target_modules'de
gate
ekleme — router'ı dondur veya çok küçük LoRA
Cookbook'un kuralı (Mixtral 8×7B FT):
  • α = 0.01 (training default'tan 10× yüksek)
  • LR = 1e-4 (Llama'da 2e-4'ün yarısı)
  • Target modules:
    q_proj, k_proj, v_proj, o_proj
    ONLY (gate ve experts'i dondur — minimum kalite kaybı için)
  • Epoch ≤ 2
✅ Teslim
  1. Mixtral 8×7B mini-FT lab koş. 2) RouterMonitor ile balance metrics logla. 3) Sonraki ders: 5.3 — DeepSeek-V3 / R1 (671B, 37B active).

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

İlgili İçerikler