Mixtral 8×7B / 8×22B FT: Router Collapse Problemi + Aux Loss Weight Kalibrasyon
Mixtral'in FT'sinde en sık karşılaşılan bug: **router collapse** — eğitim ilerledikçe bir expert dominat olur, diğerleri dead. Capacity overflow, aux loss weight'in dinamik adaptasyonu, expert balance metrics ölçümü, FSDP + MoE uyumu (expert parallelism). 4×H100 80GB Mixtral 8×7B QLoRA reçetesi (~4 saat).
Şükrü Yusuf KAYA
32 dakikalık okuma
İleripython
# === Expert Balance Metrics ölçümü ===# Mixtral router'ı her step için expert distribution'unu loglaimport torch class RouterMonitor: def __init__(self, num_experts=8): self.num_experts = num_experts self.token_counts = torch.zeros(num_experts) self.weight_sums = torch.zeros(num_experts) def update(self, router_logits, expert_indices): """Her FFN forward call'unda çağır.""" with torch.no_grad(): for i in range(self.num_experts): mask = (expert_indices == i) self.token_counts[i] += mask.sum().item() self.weight_sums[i] += router_logits[mask].sum().item() def report(self): total = self.token_counts.sum() if total == 0: return {} f = self.token_counts / total return { "expert_load_ratio": f.tolist(), "expert_load_std": f.std().item(), # 0 = perfect balance "expert_load_max": f.max().item(), # >0.4 → router collapse riski "expert_load_min": f.min().item(), # <0.05 → dead expert } # Training callback ile entegre etfrom transformers import TrainerCallback class MoEMonitorCallback(TrainerCallback): def __init__(self, model): self.monitors = {} for layer_idx, layer in enumerate(model.model.layers): if hasattr(layer, "block_sparse_moe"): self.monitors[layer_idx] = RouterMonitor() def on_step_end(self, args, state, control, **kwargs): if state.global_step % 50 == 0: for idx, mon in self.monitors.items(): report = mon.report() print(f"layer-{idx}: load_std={report['expert_load_std']:.3f}, max={report['expert_load_max']:.3f}") # Sağlıklı bir Mixtral FT'inde her layer'da:# - load_std ~0.03-0.08 (uniform'a yakın)# - load_max ~0.15-0.20 (ideal: 1/8 = 0.125)# - load_min ~0.08-0.12 # Router collapse'da:# - load_std > 0.15# - load_max > 0.40 (bir expert tüm token'ları yutuyor)# - load_min < 0.02 (dead expert)expert balance monitoring
1. Router Collapse Mitigation#
| Sebep | Çözüm |
|---|---|
| Aux loss weight çok düşük | α'yı 0.001 → 0.01'e çık |
| Learning rate çok yüksek | LR'i 1/2-1/3'e indir |
| Dataset bias (tek-dil dominat) | Veri mix balance |
| FT epoch sayısı yüksek | Epoch azalt (overfitting risk MoE'de yüksek) |
| Router frozen değil | LoRA target_modules'de gate |
Cookbook'un kuralı (Mixtral 8×7B FT):
- α = 0.01 (training default'tan 10× yüksek)
- LR = 1e-4 (Llama'da 2e-4'ün yarısı)
- Target modules: ONLY (gate ve experts'i dondur — minimum kalite kaybı için)
q_proj, k_proj, v_proj, o_proj - Epoch ≤ 2
✅ Teslim
- Mixtral 8×7B mini-FT lab koş. 2) RouterMonitor ile balance metrics logla. 3) Sonraki ders: 5.3 — DeepSeek-V3 / R1 (671B, 37B active).
Yorumlar & Soru-Cevap
(0)Yorum yazmak için giriş yap.
Yorumlar yükleniyor...
İlgili İçerikler
Part 0 — Engineering Foundations
Fine-Tuning Cookbook'a Hoş Geldin: Sistematik, Stage Taksonomisi ve Reproducibility Kontratı
Öğrenmeye BaşlaPart 0 — Engineering Foundations
Reproducibility Stack: Seeds, cuDNN Flags ve Deterministic CUDA — 'Sende Niye Çalışıyor Bende Çalışmıyor' Sorununu Bitir
Öğrenmeye BaşlaPart 0 — Engineering Foundations