
AdamW + Learning Rate Schedule: The Mathematical Anatomy of Modern LLM Optimization

Modern LLM optimization: the evolution from SGD to Adam, and from Adam to AdamW. Loshchilov 2019's weight decay decoupling. Intuition for momentum (β1=0.9) and the variance estimate (β2=0.95). Learning rate schedules: cosine decay, linear decay, and why warmup is required. Gradient clipping, mixed precision training, hyperparameter pitfalls.

Şükrü Yusuf KAYA
70-minute read
Advanced
🎯 AdamW — the standard for modern LLM optimization
GPT-3, Llama-3, Mistral, Gemini — they all use AdamW. Loshchilov 2019's paper 'Decoupled Weight Decay Regularization' made one small change to the Adam optimizer: decouple weight decay from the gradient. The result: a dramatic improvement in stability and generalization for 7B+ model training. In modern LLM optimization, AdamW + cosine learning rate schedule + warmup + gradient clipping is the de facto standard. 70 minutes from now, you will have a deep grasp of AdamW's mathematical anatomy, the intuition behind hyperparameter choices, and the effect of the learning rate schedule on gradient flow.

Lesson Map (10 Sections)#

  1. The evolution from SGD to Adam — momentum, adaptive lr
  2. Adam (Kingma 2014) — first moment + second moment
  3. L2 vs weight decay — why they differ
  4. AdamW (Loshchilov 2019) — decoupling weight decay
  5. Hyperparameters — choosing β1, β2, ε, weight_decay
  6. Learning rate schedule — warmup + cosine decay
  7. Why warmup — adaptive lr stabilization
  8. Gradient clipping — explosion prevention
  9. Mixed precision considerations — fp32 master weights
  10. Llama-3 hyperparameters — production values

1-4. SGD → Adam → AdamW#

1.1 SGD#

Standard gradient descent:
θ_{t+1} = θ_t - lr × ∇L(θ_t)
Problem: every parameter shares the same lr. Insufficient for sparse features.

1.2 SGD + momentum#

m_{t+1} = β × m_t + ∇L(θ_t)
θ_{t+1} = θ_t - lr × m_{t+1}
Momentum: the gradient's 'inertia'. It dampens oscillating gradients and reinforces the persistent direction.
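The two momentum equations above can be sketched in a few lines of Python (a toy scalar version; `momentum_step` is an illustrative name, not a library API):

```python
# Toy scalar SGD+momentum step, following the two equations above.
def momentum_step(theta, m, grad, lr=0.01, beta=0.9):
    m = beta * m + grad       # accumulate the gradient's "inertia"
    theta = theta - lr * m    # step along the accumulated direction
    return theta, m

# Persistent gradients compound: after three steps with g = 1.0,
# m = 1 + 0.9 + 0.81 = 2.71, so the effective step size grows.
theta, m = 0.0, 0.0
for g in [1.0, 1.0, 1.0]:
    theta, m = momentum_step(theta, m, g)
```

With alternating gradients (+1, -1, +1, …) the same loop keeps m near zero — exactly the dampening described above.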

1.3 Adam (Kingma 2014)#

Two estimates:
  • First moment (mean): m_t = β1 × m_{t-1} + (1-β1) × g_t
  • Second moment (variance): v_t = β2 × v_{t-1} + (1-β2) × g_t²
Bias correction:
  • m_hat = m_t / (1 - β1^t)
  • v_hat = v_t / (1 - β2^t)
Update:
θ_{t+1} = θ_t - lr × m_hat / (sqrt(v_hat) + ε)
Intuition: the lr is adaptively scaled per parameter.
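Putting the four formulas together, a single Adam step for a scalar parameter looks like this (a from-scratch sketch for intuition, not the PyTorch implementation):

```python
import math

# One Adam step for a scalar parameter, following the formulas above.
def adam_step(theta, g, m, v, t, lr=1e-3, b1=0.9, b2=0.95, eps=1e-8):
    m = b1 * m + (1 - b1) * g          # first moment (mean estimate)
    v = b2 * v + (1 - b2) * g * g      # second moment (variance estimate)
    m_hat = m / (1 - b1 ** t)          # bias correction
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# At t=1, bias correction gives m_hat = g and v_hat = g², so the step is
# ≈ lr × sign(g), regardless of the gradient's raw magnitude:
theta, m, v = adam_step(0.0, 5.0, 0.0, 0.0, t=1)
```

This is the 'adaptive' part in action: the raw gradient magnitude cancels out of the first step, and later steps are scaled by the running variance estimate.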

1.4 Adam + L2 weight decay#

Classic L2 regularization:
L_total = L_task + (λ/2) × ||θ||²
Gradient: ∇L_total = ∇L_task + λ × θ
When Adam is applied:
m_t = β1 × m_{t-1} + (1-β1) × (∇L_task + λθ_t)
Problem: the weight decay term passes through the adaptive scaling → the effective decay differs per parameter.
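A quick numeric illustration of the coupling (hypothetical values): the λθ term is divided by the same sqrt(v_hat) as the task gradient, so the decay a parameter actually receives depends on its gradient history.

```python
import math

# Approximate per-parameter decay contribution of the λθ term in
# Adam-with-L2: it is divided by sqrt(v_hat), so it depends on each
# parameter's gradient variance estimate.
def effective_decay(theta, v_hat, lam=0.1, lr=1e-3, eps=1e-8):
    return lr * lam * theta / (math.sqrt(v_hat) + eps)

# Two parameters with the same θ = 1.0 but different variance estimates:
rare = effective_decay(1.0, v_hat=1e-4)      # rarely-updated parameter
frequent = effective_decay(1.0, v_hat=1e2)   # large-gradient parameter
# The rarely-updated parameter is decayed ~1000x more strongly.
```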

1.5 AdamW (Loshchilov 2019)#

Decouple weight decay from the gradient:
# Gradient update (no weight decay in the gradient)
m_t = β1 × m_{t-1} + (1-β1) × g_t
v_t = β2 × v_{t-1} + (1-β2) × g_t²
m_hat = m_t / (1 - β1^t)
v_hat = v_t / (1 - β2^t)
# Apply update + decoupled weight decay
θ_{t+1} = θ_t - lr × (m_hat / (sqrt(v_hat) + ε) + weight_decay × θ_t)
Weight decay is applied directly to θ; it never passes through the gradient path.
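The same scalar sketch with the decoupled update (illustrative, not the PyTorch source):

```python
import math

# One AdamW step: identical moment estimates, but weight decay is applied
# directly to θ, outside the adaptive scaling.
def adamw_step(theta, g, m, v, t, lr=3e-4, b1=0.9, b2=0.95,
               eps=1e-5, weight_decay=0.1):
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps)
                          + weight_decay * theta)
    return theta, m, v

# With zero gradient the parameter still decays by exactly lr × wd × θ,
# independent of v_hat — the decoupling in action:
theta, m, v = adamw_step(1.0, 0.0, 0.0, 0.0, t=1)   # θ → 1 - 3e-4 × 0.1
```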

1.6 Empirical improvement#

Loshchilov 2019: AdamW vs Adam-L2 on transformer training:
  • AdamW: better generalization
  • Lower validation loss
  • Modern LLMs all use AdamW

1.7 PyTorch AdamW#

import torch.optim as optim

optimizer = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),  # β1, β2
    eps=1e-5,
    weight_decay=0.1,
)

6-9. Learning Rate Schedule#

6.1 Why a constant lr is not enough#

Deep network training:
  • Early: large updates (model far from the optimum)
  • Late: small updates (fine-tuning around the optimum)
A constant lr is either unstable at the start or insufficient at the end.

6.2 Linear warmup#

From a very small lr up to max lr, increasing linearly:
lr(t) = max_lr × min(1, t / warmup_steps)
Why warmup is necessary:
  • AdamW's second moment v_t is unstable early in training
  • A very small v_t → the update ratio blows up → exploding updates
  • Warmup gives v_t time to stabilize
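The warmup formula as a one-liner, with a few checkpoints (defaults match the Llama-3 numbers used later):

```python
# lr(t) = max_lr × min(1, t / warmup_steps)
def warmup_lr(step, max_lr=3e-4, warmup_steps=2000):
    return max_lr * min(1.0, step / warmup_steps)

warmup_lr(0)      # 0.0    — start from (near) zero
warmup_lr(1000)   # 1.5e-4 — halfway through warmup
warmup_lr(5000)   # 3e-4   — capped at max_lr after warmup
```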

6.3 Cosine decay#

Main training phase:
lr(t) = lr_min + 0.5 × (max_lr - lr_min) × (1 + cos(π × (t - warmup) / (total - warmup)))
lr decays smoothly from max to min over training. The cosine shape is stable, with no harsh discontinuities.

6.4 Linear decay#

Alternative: linear decay instead of cosine. Empirically similar quality.

6.5 Llama-3 schedule#

  • warmup: 2000 steps (linear)
  • decay: cosine, max 3e-4 → min 3e-5 (10x)
  • total: 1.4M steps

6.6 Gradient clipping#

Prevent gradient explosion:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
If gradient norm > 1.0, scale all gradients to norm = 1.0.
Llama-3: max_norm=1.0. Stabilizes training in NaN-prone regions.
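What `clip_grad_norm_` does, in miniature — a pure-Python sketch of global-norm clipping (not the PyTorch source):

```python
import math

# Scale every gradient by max_norm / total_norm when the global norm
# exceeds max_norm; direction is preserved, magnitude is capped.
def clip_by_global_norm(grads, max_norm=1.0):
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

clipped, norm = clip_by_global_norm([3.0, 4.0])  # norm 5.0 → rescaled
# clipped ≈ [0.6, 0.8]: same direction, global norm now exactly 1.0
```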

6.7 Mixed precision#

Llama-3 production:
  • Forward + backward: bf16 (BFloat16)
  • Optimizer state: fp32 (full precision)
  • Master weights: fp32
  • Gradient: bf16 → fp32 conversion
bf16 forward: 2x faster, half the memory. fp32 master weights: preserve precision (avoid catastrophic accumulation errors).
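Why fp32 master weights matter can be shown with a pure-Python simulation of bf16 (here, truncating a float32 to its top 16 bits — real hardware rounds to nearest, but the effect is the same for this example):

```python
import struct

# Crude bf16 simulation: keep only the top 16 bits of the float32
# representation (sign + exponent + 7 mantissa bits).
def to_bf16(x):
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    return struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))[0]

w, update = 1.0, 1e-4                    # a typical tiny optimizer update
in_bf16 = to_bf16(to_bf16(w) + update)   # update falls below bf16 resolution
in_fp32 = w + update                     # fp32 master weight keeps it
# in_bf16 == 1.0 (update silently lost), in_fp32 == 1.0001
```

Applied every step, this is exactly the catastrophic accumulation error that fp32 master weights prevent.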

6.8 Production checklist#

import math
from torch.optim.lr_scheduler import LambdaLR

def cosine_schedule_with_warmup(step, warmup_steps, total_steps, min_ratio=0.1):
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_ratio + 0.5 * (1 - min_ratio) * (1 + math.cos(math.pi * progress))

scheduler = LambdaLR(
    optimizer,
    lambda step: cosine_schedule_with_warmup(step, 2000, 1_400_000),
)
✅ Lesson 11.2 Summary — AdamW Optimizer
Modern LLM optimization: AdamW (Loshchilov 2019, Adam with decoupled weight decay). Llama-3 hyperparams: lr 3e-4, β1=0.9, β2=0.95, weight_decay=0.1. Schedule: linear warmup (2000 steps) + cosine decay (3e-4 → 3e-5). Why warmup is mandatory: stabilizing AdamW's v_t. Gradient clipping with max_norm=1.0 prevents explosions. Mixed precision: bf16 forward/backward + fp32 master weights. In Lesson 11.3 we move on to the Module 11 capstone (mini Llama-3 training).

Next Lesson: Mini Llama-3 Training Capstone#

Lesson 11.3: training a 100M-parameter Llama-3-mini from scratch — corpus, tokenization, AdamW, schedule, a single H100 GPU, validation loss monitoring, checkpointing.

Frequently Asked Questions

Why β2 = 0.95 instead of Adam's default 0.999? Llama paper: β2 = 0.95 is more stable in large-scale LLM training. A lower β2 makes the second moment estimate more responsive, so it adapts faster to large gradient shifts. β2 = 0.999 is fine for small models but causes instability in 7B+ models.
