Mixtral 8×7B / 8×22B FT: Router Collapse Problem + Aux Loss Weight Calibration
The most common Mixtral FT bug is **router collapse**: as training progresses, one expert dominates while the others go dead. This lesson covers capacity overflow, dynamic aux-loss adaptation, expert balance metrics, and FSDP + MoE compatibility (expert parallelism), plus a Mixtral 8×7B QLoRA recipe on 4×H100 80GB (~4h).
Şükrü Yusuf KAYA
```python
# === Measuring expert balance metrics ===
# Log the Mixtral router's expert distribution at every step
import torch

class RouterMonitor:
    def __init__(self, num_experts=8):
        self.num_experts = num_experts
        self.token_counts = torch.zeros(num_experts)
        self.weight_sums = torch.zeros(num_experts)

    def update(self, router_logits, expert_indices):
        """Call on every FFN forward pass."""
        with torch.no_grad():
            for i in range(self.num_experts):
                mask = (expert_indices == i)
                self.token_counts[i] += mask.sum().item()
                self.weight_sums[i] += router_logits[mask].sum().item()

    def report(self):
        total = self.token_counts.sum()
        if total == 0:
            return {}
        f = self.token_counts / total
        return {
            "expert_load_ratio": f.tolist(),
            "expert_load_std": f.std().item(),  # 0 = perfect balance
            "expert_load_max": f.max().item(),  # >0.4 -> router collapse risk
            "expert_load_min": f.min().item(),  # <0.05 -> dead expert
        }

# Integrate via a Trainer callback
from transformers import TrainerCallback

class MoEMonitorCallback(TrainerCallback):
    def __init__(self, model):
        # assumes the base MixtralForCausalLM layout (unwrap PEFT first)
        self.monitors = {}
        for layer_idx, layer in enumerate(model.model.layers):
            if hasattr(layer, "block_sparse_moe"):
                self.monitors[layer_idx] = RouterMonitor()

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 50 == 0:
            for idx, mon in self.monitors.items():
                report = mon.report()
                if not report:  # monitor has not been fed yet
                    continue
                print(f"layer-{idx}: load_std={report['expert_load_std']:.3f}, "
                      f"max={report['expert_load_max']:.3f}")

# In a healthy Mixtral FT, per layer:
# - load_std ~0.03-0.08 (close to uniform)
# - load_max ~0.15-0.20 (ideal: 1/8 = 0.125)
# - load_min ~0.08-0.12
# Under router collapse:
# - load_std > 0.15
# - load_max > 0.40 (one expert swallows every token)
# - load_min < 0.02 (dead expert)
```
*Expert balance monitoring*
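Note that nothing in the block above actually feeds `RouterMonitor.update`; the callback only reads the monitors. Below is a minimal wiring sketch, assuming the Hugging Face Mixtral layout where each layer's router is the `Linear` at `block_sparse_moe.gate`. The `attach_router_hooks` helper and the `top_k=2` default are assumptions, not part of the original code:

```python
import torch

def attach_router_hooks(model, monitors, top_k=2):
    """Hypothetical helper: forward hooks on each MoE gate feed
    RouterMonitor.update with the actual routing decisions."""
    handles = []
    for layer_idx, layer in enumerate(model.model.layers):
        mon = monitors.get(layer_idx)
        if mon is None:
            continue

        def hook(module, inputs, output, mon=mon):
            # output: router logits, shape (num_tokens, num_experts)
            weights = torch.softmax(output.float(), dim=-1)
            topk_w, topk_idx = weights.topk(top_k, dim=-1)  # Mixtral routes top-2
            mon.update(topk_w.flatten(), topk_idx.flatten())

        handles.append(layer.block_sparse_moe.gate.register_forward_hook(hook))
    return handles  # call h.remove() on each handle to detach
```

Attach once before training (`handles = attach_router_hooks(model, callback.monitors)`); hooking the gate keeps the monitor independent of whether a given transformers version returns router logits from the MoE block itself.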
1. Router Collapse Mitigation
| Cause | Fix |
|---|---|
| Aux loss weight too low | Raise α from 0.001 to 0.01 |
| Learning rate too high | Cut LR to 1/2 or 1/3 of the current value |
| Dataset bias (one language dominates) | Rebalance the data mix |
| Too many FT epochs | Reduce epochs (overfitting risk is higher for MoE) |
| Router not frozen | Remove `gate` from LoRA target_modules |
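The first row of the table can also be automated. Below is a sketch of the dynamic aux-loss adaptation mentioned in the intro, assuming the `RouterMonitor` instances from above and the collapse thresholds from the monitoring comments; the `DynamicAuxLossCallback` name and the doubling/halving schedule are assumptions, not a fixed recipe:

```python
from transformers import TrainerCallback

class DynamicAuxLossCallback(TrainerCallback):
    """Hypothetical sketch: raise the aux-loss weight alpha when expert
    load imbalance crosses the collapse thresholds, relax it when balanced."""
    def __init__(self, model, monitors, alpha_min=0.001, alpha_max=0.05):
        self.model = model          # base MixtralForCausalLM
        self.monitors = monitors    # {layer_idx: RouterMonitor}
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 50 != 0:
            return
        reports = [m.report() for m in self.monitors.values()]
        stds = [r["expert_load_std"] for r in reports if r]
        if not stds:
            return
        worst = max(stds)
        alpha = self.model.config.router_aux_loss_coef
        if worst > 0.15:                     # collapse territory: push alpha up
            alpha = min(alpha * 2.0, self.alpha_max)
        elif worst < 0.05:                   # well balanced: relax alpha
            alpha = max(alpha * 0.5, self.alpha_min)
        self.model.config.router_aux_loss_coef = alpha
        # some transformers versions cache the coefficient on the model itself
        if hasattr(self.model, "router_aux_loss_coef"):
            self.model.router_aux_loss_coef = alpha
```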
The Cookbook's rule (Mixtral 8×7B FT), wired up in the sketch below:
- α = 0.01 (10× the training default)
- LR = 1e-4 (half of Llama's 2e-4)
- Target modules: ONLY q_proj, k_proj, v_proj, o_proj (freeze gate and experts for minimum quality loss)
- Epochs ≤ 2
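A minimal sketch of these defaults as a transformers + peft QLoRA setup; the quantization settings and LoRA rank here are illustrative assumptions, and the dataset/Trainer wiring is omitted:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,     # include the load-balancing aux loss
    router_aux_loss_coef=0.01,     # alpha: 10x the 0.001 default
    quantization_config=bnb,
    device_map="auto",
)

lora = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,   # illustrative rank/alpha
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # gate/experts stay frozen
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)

args = TrainingArguments(
    output_dir="mixtral-qlora",
    learning_rate=1e-4,            # half the usual Llama 2e-4
    num_train_epochs=2,            # epochs <= 2
    bf16=True,
)
```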
✅ Deliverables
1) Run the Mixtral 8×7B mini-FT lab. 2) Log balance metrics with RouterMonitor. 3) Next lesson: 5.3, DeepSeek-V3 / R1 (671B total, 37B active).