Normalization Revolution: LayerNorm, RMSNorm and Pre-LN vs Post-LN — Cornerstone of Training Stability
Mathematical and systems anatomy of transformer training stability: LayerNorm (Ba 2016) classical formula, RMSNorm (Zhang 2019) — Llama-3's choice, why gain parameter only, computational savings. Pre-LN (modern) vs Post-LN (original Vaswani) trade-off, gradient flow, deep transformer stability. Normalization concerns in Turkish model fine-tuning.
Şükrü Yusuf KAYA
70 min read
Advanced 🌡️ Normalization: the thermostat of a deep transformer
You want to train a 100+ layer transformer. On the very first forward pass the activation magnitudes either explode (1e6+) or vanish (1e-6). Gradient flow becomes impossible and training collapses into NaNs. This is the fundamental problem of deep neural networks. The solution came with Jimmy Ba's LayerNorm in 2016 and Biao Zhang's RMSNorm in 2019. Every Llama-3 block contains 2 normalization layers, 64 normalizations in total (32 blocks × 2): the invisible heroes that keep the model stable. 70 minutes from now you will have a deep understanding of the mathematical anatomy of LayerNorm/RMSNorm, the effect of the Pre-LN/Post-LN architectural choice on perplexity, and why modern models settled on the RMSNorm + Pre-LN standard.
Lesson Map (12 Sections)#
- Why normalization: the internal covariate shift problem
- The limits of BatchNorm: why it falls short for sequence models
- LayerNorm (Ba 2016): formula + intuition
- LayerNorm gain + bias parameters: learnable
- RMSNorm (Zhang 2019): bias dropped, root mean square only
- RMSNorm computational savings: ~30% faster
- Pre-LN vs Post-LN: the architectural trade-off
- The original Vaswani Post-LN: why it is unstable
- Modern Pre-LN: the Llama-3, GPT-4 standard
- Gradient flow analysis: the math of deep transformer stability
- Llama-3 in production: combining RMSNorm + Pre-LN
- Turkish fine-tuning: normalization concerns
1-3. Normalization Fundamentals#
1.1 Internal covariate shift#
Ioffe & Szegedy 2015 named the problem 'internal covariate shift': the hidden layers of a deep network experience distribution shift during training:
- Every parameter update changes the activation distribution
- Each subsequent layer has to adapt to a different input distribution at every step
- Training is unstable, so the learning rate has to stay small
Normalization forces the activations into a fixed distribution (zero mean, unit variance).
1.2 BatchNorm (Ioffe 2015)#
The first big success. Normalize over the batch dimension:
μ = mean(x, dim=batch)
σ = std(x, dim=batch)
x_norm = (x - μ) / σ
It had a dramatic effect in ResNets and set ImageNet records.
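A minimal check of the formula (a sketch assuming PyTorch; the shapes are illustrative): the hand-computed batch statistics match `nn.BatchNorm1d` in training mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(64, 128) * 5 + 3                   # (batch, features), deliberately shifted and scaled

# Manual BatchNorm: statistics over the batch dimension, one μ/σ per feature
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_manual = (x - mu) / torch.sqrt(var + 1e-5)

bn = nn.BatchNorm1d(128, affine=False)             # default eps=1e-5, no learnable γ/β
bn.train()
print(torch.allclose(x_manual, bn(x), atol=1e-5))  # True
```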
1.3 BatchNorm's problems in sequence models#
- Batch size dependence: with small batches the statistics are high-variance
- Sequence length: variable lengths do not fit the batch-statistics scheme
- Inference time: batch=1 is a problem
- Recurrent connections: at which step should BatchNorm be applied?
BatchNorm is a poor fit for NLP; the sketch below makes the batch dependence concrete.
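A toy sketch (illustrative shapes, assuming PyTorch): the same sample is normalized differently depending on which batch it lands in, and at inference with batch=1 the layer has to fall back to running statistics collected during training.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(8, affine=False)
sample = torch.randn(1, 8)

# Training mode: the same sample is normalized differently depending on its batch neighbours
bn.train()
out_a = bn(torch.cat([sample, torch.randn(3, 8)]))[0]
out_b = bn(torch.cat([sample, torch.randn(3, 8) * 10]))[0]
print(torch.allclose(out_a, out_b))   # False: the batch changes the statistics

# Inference mode with batch=1: forced to rely on running statistics from training
bn.eval()
print(bn(sample))                     # normalized with running_mean / running_var, not this sample
```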
1.4 LayerNorm (Ba 2016)#
Jimmy Ba's solution: normalize over the feature dimension instead of the batch dimension.
For each token (or each sample):
μ = mean(x, dim=features)          # scalar per sample
σ² = var(x, dim=features)          # scalar per sample
x_norm = (x - μ) / sqrt(σ² + ε)    # ε for numerical stability
y = γ × x_norm + β                 # learnable scale + shift
γ (gain) and β (bias) are learnable parameters, one per feature.
1.5 LayerNorm intuition#
- Normalize each token's hidden vector
- Mean 0, variance 1
- Re-scale with γ and β (the model needs that flexibility)
Independent of the batch → ideal for sequence models. A minimal check of the formula follows below.
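A small sketch (assuming PyTorch) showing that the per-token formula above matches `nn.LayerNorm` over the last dimension, with γ at its default of ones and β at zeros:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 4, 16)                              # (batch, seq, features)

mu = x.mean(dim=-1, keepdim=True)                      # per-token mean
var = x.var(dim=-1, keepdim=True, unbiased=False)      # per-token variance
x_manual = (x - mu) / torch.sqrt(var + 1e-5)           # γ=1, β=0 case

ln = nn.LayerNorm(16)                                  # γ init 1, β init 0, eps=1e-5
print(torch.allclose(x_manual, ln(x), atol=1e-5))      # True
```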
5-6. RMSNorm: Llama-3's Choice#
5.1 Zhang 2019 paper#
'Root Mean Square Layer Normalization'. The insight: the mean-centering part of LayerNorm is not strictly necessary.
RMSNorm:
rms(x) = sqrt(mean(x²))    # only the root mean square
x_norm = x / rms(x)        # no mean subtraction
y = γ × x_norm             # only gain, no bias
No mean computation, no bias parameter.
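The only thing RMSNorm drops is mean centering. A small check (PyTorch, illustrative shapes) of the identity rms(x)² = var(x) + mean(x)² makes the relationship explicit: when the per-token mean is near zero, RMSNorm and LayerNorm scale activations almost identically.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 16)

rms2 = x.pow(2).mean(dim=-1)       # rms(x)²
var = x.var(dim=-1, unbiased=False)
mean = x.mean(dim=-1)

# E[x²] = Var[x] + E[x]², so RMSNorm differs from LayerNorm only through the mean term
print(torch.allclose(rms2, var + mean.pow(2), atol=1e-5))  # True
```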
5.2 Computational comparison#
LayerNorm per token:
1. Compute mean (sum + divide)
2. Subtract mean (broadcast)
3. Compute variance (square, sum, divide)
4. Compute std (sqrt)
5. Divide
6. Multiply by gain
7. Add bias
RMSNorm per token:
1. Compute mean of squares (square, sum, divide)
2. Compute rms (sqrt)
3. Divide
4. Multiply by gain
RMSNorm: 4 steps vs 7. In practice the norm op runs roughly 30% faster, though the exact gain depends on the implementation and hardware.
5.3 Quality preservation#
Zhang 2019's empirical result: RMSNorm matches LayerNorm quality in most settings. Mean centering helps in some cases, but it is rarely critical.
5.4 Adoption#
- T5: RMSNorm
- Llama-1, Llama-2, Llama-3, Mistral, Mixtral, Qwen: RMSNorm
- GPT-4 (unconfirmed): LayerNorm or RMSNorm (architecture not public)
- BERT, GPT-2, GPT-3: LayerNorm (legacy)
Modern trend: RMSNorm dominance.
5.5 PyTorch implementation#
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Normalize over the last (feature) dimension
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gain * x / rms
```
The core is only a few lines. For production use, replace it with a fused/Triton kernel.
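A quick sanity check (assuming the RMSNorm class defined just above and PyTorch): after normalization, each token's root mean square is pulled to roughly 1, since the gain is initialized to ones.

```python
import torch

x = torch.randn(2, 8, 512) * 7                # deliberately large activations
norm = RMSNorm(512)                           # the RMSNorm class defined above
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt()[0, :3])    # per-token RMS ≈ 1.0
```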
5.6 Why Llama-3 uses RMSNorm#
- The normalization op is roughly 30% faster
- The quality difference is minor
- It is the modern best practice
- Mean centering is unnecessary in most cases
- Parameter savings: no bias vector per norm layer (modest in absolute terms; see the sketch below)
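How modest? A back-of-the-envelope count, assuming Llama-3-8B-like dimensions (d_model = 4096, 32 blocks, two norms per block plus one final norm; the figures are illustrative):

```python
# Back-of-the-envelope, with assumed Llama-3-8B-like dimensions (illustrative only)
d_model = 4096
n_blocks = 32
norms_per_block = 2                  # pre-attention + pre-FFN RMSNorm
final_norm = 1                       # one norm before the output head

n_norm_layers = n_blocks * norms_per_block + final_norm   # 65
bias_params_saved = n_norm_layers * d_model
print(f"{bias_params_saved:,} bias parameters avoided")   # 266,240: tiny next to ~8B total
```

Llama's larger savings come from dropping biases in its linear layers, which is a separate design choice from RMSNorm itself.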
7-10. Pre-LN vs Post-LN#
7.1 Original Vaswani 2017 (Post-LN)#
In the original transformer paper:
```python
# Post-LN architecture
x = LayerNorm(x + Attention(x))   # LN applied after the residual sum
x = LayerNorm(x + FFN(x))         # LN applied after the residual sum
```
The norm sits at the output of each sublayer (post = after).
7.2 Pre-LN (modern)#
```python
# Pre-LN architecture
x = x + Attention(LayerNorm(x))   # LN on the sublayer input
x = x + FFN(LayerNorm(x))         # LN on the sublayer input
```
The norm sits at the input of each sublayer (pre = before).
7.3 Why Pre-LN is the modern standard#
Xiong et al. 2020 paper: 'On Layer Normalization in the Transformer Architecture'.
Post-LN problems:
- Gradient magnitudes blow up in the deeper layers at initialization
- Learning rate warmup is mandatory (otherwise training hits NaN)
- Training is unstable for deeper transformers (roughly 12+ layers)
Pre-LN advantages:
- Gradient magnitudes stay stable across layers
- No warmup needed (or only a minimal one)
- 100+ layer transformers become feasible
7.4 Math: gradient flow#
In Pre-LN the residual connection gives a direct path:
x_L = x_0 + Σ_i sublayer_i(LN(x_i))
The gradient sees a direct path: ∂x_L / ∂x_0 = I + (terms from the sublayer branches), so an identity component always survives. Even in a deep network the gradient does not vanish.
In Post-LN:
x_L = LN(x_{L-1} + sublayer(x_{L-1}))
Every LN on the main path rescales the gradient. The effect compounds: the deeper the stack, the smaller the gradient that reaches the early layers. The toy simulation below illustrates this.
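A rough way to see the compounding is a toy simulation (a sketch, not a real transformer: each "sublayer" is a plain Linear layer and the norm is standard LayerNorm). Stacking many Post-LN blocks shrinks the gradient that reaches the input, while the Pre-LN stack keeps an identity path open:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
DIM, DEPTH = 64, 32

def make(depth):
    subs = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(depth)])
    norms = nn.ModuleList([nn.LayerNorm(DIM) for _ in range(depth)])
    return subs, norms

def pre_ln(x, subs, norms):
    for f, ln in zip(subs, norms):
        x = x + f(ln(x))                 # Pre-LN: norm inside the residual branch
    return x

def post_ln(x, subs, norms):
    for f, ln in zip(subs, norms):
        x = ln(x + f(x))                 # Post-LN: norm after the residual sum
    return x

def input_grad_norm(block_fn):
    subs, norms = make(DEPTH)
    x = torch.randn(1, DIM, requires_grad=True)
    block_fn(x, subs, norms).sum().backward()
    return x.grad.norm().item()

print("Pre-LN  |grad wrt input|:", input_grad_norm(pre_ln))
print("Post-LN |grad wrt input|:", input_grad_norm(post_ln))
```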
7.5 Modern model preferences#
Pre-LN:
- GPT-2 (LN moved to the input of each sub-block)
- GPT-3+
- Llama-1, Llama-2, Llama-3
- Mistral
- Mixtral
- Claude (presumed; architecture not public)
- BLOOM
Post-LN:
- Original BERT
- (T5 is often listed here, but it actually normalizes sub-block inputs with its simplified LayerNorm, i.e. pre-norm)
- Original Vaswani transformer
7.6 Practical impact#
Training with Pre-LN:
- Learning rate: around 1e-4, with little or no warmup
- Stable convergence
- The 32-layer Llama-3 8B trains smoothly
Training with Post-LN:
- Learning rate: around 1e-5, with mandatory warmup (1000+ steps; a sketch of such a schedule follows below)
- Often NaN at scale
- Deep model training is risky
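For reference, the warmup that Post-LN typically demands looks something like this linear ramp (the 1e-5 base rate and 1000-step horizon are illustrative values, not prescriptions from any particular paper):

```python
# Illustrative linear-warmup schedule; base_lr and warmup_steps are example values
def warmup_lr(step, base_lr=1e-5, warmup_steps=1000):
    return base_lr * min(1.0, step / warmup_steps)

for s in (1, 100, 500, 1000, 5000):
    print(f"step {s:>4}: lr = {warmup_lr(s):.2e}")
```

In PyTorch a rule like this is typically wrapped in torch.optim.lr_scheduler.LambdaLR.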
7.7 Llama-3 specific#
```python
# For each transformer block:
h = h + Attention(RMSNorm(h))   # Pre-LN + RMSNorm
h = h + FFN(RMSNorm(h))         # Pre-LN + RMSNorm
```
The RMSNorm + Pre-LN combination is the cornerstone of Llama-3's training stability.
```python
import time
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Llama-3 style RMSNorm."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Compute RMS along the last dim
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms


class LayerNorm(nn.Module):
    """Classic LayerNorm."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gain * (x - mean) / torch.sqrt(var + self.eps) + self.bias


# Pre-LN block (modern)
class PreLNBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.SiLU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        normed = self.norm1(x)
        h = x + self.attn(normed, normed, normed)[0]   # Pre-LN attention
        h = h + self.ffn(self.norm2(h))                # Pre-LN FFN
        return h


# Performance test: RMSNorm vs LayerNorm (requires a CUDA GPU)
d_model = 4096
batch, seq = 32, 1024
x = torch.randn(batch, seq, d_model, device='cuda')

rms = RMSNorm(d_model).cuda()
ln = LayerNorm(d_model).cuda()


def bench(fn, n=100):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n):
        _ = fn(x)
    torch.cuda.synchronize()
    return (time.time() - start) / n * 1000


ln_ms, rms_ms = bench(ln), bench(rms)
print(f"LayerNorm: {ln_ms:.2f} ms")
print(f"RMSNorm:   {rms_ms:.2f} ms")
print(f"Speedup:   {ln_ms / rms_ms:.2f}x")
```
RMSNorm vs LayerNorm comparison: production benchmark
✅ Lesson 10.1 Summary: Normalization
Normalization is the cornerstone of transformer training stability. LayerNorm (Ba 2016): normalize over the feature dim, mean centering + variance scaling, learnable gain + bias. RMSNorm (Zhang 2019): NO mean centering, only root mean square + gain; a roughly 30% faster norm op and a small parameter saving. Llama-3's choice: RMSNorm. Pre-LN vs Post-LN: the original Vaswani Post-LN is unstable in deep transformers (exploding and vanishing gradients with depth). Pre-LN is the modern standard: gradient flow stays stable, learning rate warmup is not required, and 100+ layer training becomes possible. Llama-3's architecture combines RMSNorm + Pre-LN, the modern standard combo. In Lesson 10.2 we move on to the SwiGLU activation function.
Next Lesson: SwiGLU Activation Function#
Lesson 10.2: SiLU + GLU = SwiGLU (Shazeer 2020), why it replaces ReLU/GeLU, the Llama-3 implementation, FFN dimensions, performance.
Frequently Asked Questions
Should I always prefer RMSNorm over LayerNorm?
Generally yes; it is the default in modern models. BUT: small models or specific tasks may still benefit from LayerNorm's mean centering. In production: Llama/Mistral/T5 use RMSNorm, BERT/GPT-2 use LayerNorm.