
The Normalization Revolution: LayerNorm, RMSNorm, and Pre-LN vs Post-LN, the Cornerstone of Training Stability

The mathematical and systems anatomy of transformer training stability: the classic LayerNorm formula (Ba 2016); RMSNorm (Zhang 2019), Llama-3's choice, why it keeps only a gain parameter, and the computational savings. The Pre-LN (modern) vs Post-LN (original Vaswani) trade-off, gradient flow, and deep-transformer stability. Normalization concerns when fine-tuning Turkish models.

Şükrü Yusuf KAYA
70-minute read
Advanced
🌡️ Normalization: the thermostat of the deep transformer
You want to train a transformer with 100+ layers. On the very first forward pass the activation magnitudes either explode (1e6+) or vanish (1e-6). Gradient flow becomes impossible and training collapses into NaNs. This is the fundamental problem of deep neural networks. The solution arrived with Jimmy Ba's LayerNorm in 2016 and Biao Zhang's RMSNorm in 2019. Every Llama-3 block contains 2 normalization layers, 64 in total (32 blocks × 2 in the 8B model); they are the invisible hero behind the model's stability. 70 minutes from now you will have a deep grasp of the mathematical anatomy of LayerNorm/RMSNorm, the effect of the Pre-LN/Post-LN architectural choice on perplexity, and why modern models standardized on RMSNorm + Pre-LN.

Lesson Map (12 Sections)#

  1. Why normalization: the internal covariate shift problem
  2. The limits of BatchNorm: why it falls short for sequence models
  3. LayerNorm (Ba 2016): formula + intuition
  4. LayerNorm gain + bias parameters: learnable
  5. RMSNorm (Zhang 2019): drop the bias, root mean square only
  6. RMSNorm computational savings: ~30% faster
  7. Pre-LN vs Post-LN: the architectural trade-off
  8. The original Vaswani Post-LN: why it is unstable
  9. Modern Pre-LN: the Llama-3, GPT-4 standard
  10. Gradient flow analysis: the math of deep transformer stability
  11. Llama-3 in production: RMSNorm + Pre-LN combined
  12. Turkish fine-tuning: normalization concerns

1-4. Normalization Fundamentals#

1.1 Internal covariate shift#

Ioffe & Szegedy 2015 named the 'internal covariate shift' problem: a deep network's hidden layers experience distribution shift during training:
  • Every parameter update changes the activation distribution
  • The next layer has to adapt to a different distribution at every step
  • Training is unstable, and the learning rate has to stay small
Normalization forces activations toward a fixed distribution (zero mean, unit variance).

1.2 BatchNorm (Ioffe 2015)#

The first big success. Normalize over the batch dimension:
μ = mean(x, dim=batch)
σ = std(x, dim=batch)
x_norm = (x - μ) / σ
Dramatic impact in ResNet, ImageNet records.
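A minimal sketch to make the formula concrete: for a (batch, features) tensor, normalizing over dim=0 by hand should match what torch.nn.BatchNorm1d does in training mode (affine disabled here to isolate the normalization; BatchNorm uses the biased variance for this step).

import torch
import torch.nn as nn

x = torch.randn(64, 512)                      # (batch, features)
bn = nn.BatchNorm1d(512, affine=False)        # no learnable gamma/beta, just normalization
bn.train()                                    # use batch statistics, not running stats

mean = x.mean(dim=0)                          # per-feature mean over the batch
var = x.var(dim=0, unbiased=False)            # per-feature (biased) variance over the batch
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(bn(x), manual, atol=1e-5))  # expected: True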

1.3 BatchNorm problems in sequence models#

  • Batch size dependence: with small batches the statistics have high variance
  • Sequence length: variable lengths don't fit the batch statistics
  • Inference time: batch=1 is a problem
  • Recurrent connections: at which step does BatchNorm apply?
BatchNorm is a poor fit for NLP; the short sketch below illustrates the batch-size dependence.
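A tiny illustration of the first bullet (the numbers are random, so treat the exact values as indicative only): batch statistics estimated from 2 samples are far noisier than from 256 samples.

import torch

torch.manual_seed(0)
x_small = torch.randn(2, 512)     # batch of 2
x_large = torch.randn(256, 512)   # batch of 256

# Per-feature means should be ~0; their spread shows how noisy the estimate is.
print(x_small.mean(dim=0).std())  # large spread -> noisy normalization statistics
print(x_large.mean(dim=0).std())  # much smaller spread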

1.4 LayerNorm (Ba 2016)#

Jimmy Ba's solution: normalize over the feature dimension instead of the batch.
For each token (or each sample):
μ = mean(x, dim=features)          # scalar per sample
σ² = var(x, dim=features)          # scalar per sample
x_norm = (x - μ) / sqrt(σ² + ε)    # ε for numerical stability
y = γ × x_norm + β                 # learnable scale + shift
γ (gain) and β (bias) are learnable parameters, one per feature.
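A quick sanity check of the formula against PyTorch's built-in (a minimal sketch; nn.LayerNorm initializes γ=1 and β=0, so the outputs should match):

import torch
import torch.nn as nn

x = torch.randn(4, 16, 512)                        # (batch, seq, d_model)
ln = nn.LayerNorm(512)                             # default eps=1e-5, gamma=1, beta=0

mean = x.mean(dim=-1, keepdim=True)                # per-token mean over features
var = x.var(dim=-1, keepdim=True, unbiased=False)  # per-token (biased) variance
manual = (x - mean) / torch.sqrt(var + ln.eps)

print(torch.allclose(ln(x), manual, atol=1e-5))    # expected: True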

1.5 LayerNorm intuition#

  • Normalize each token's hidden vector
  • Mean 0, variance 1
  • Re-scale with γ, β (the model needs flexibility)
Independent of the batch → ideal for sequence models.

5-6. RMSNorm: Llama-3's Choice#

5.1 Zhang 2019 paper#

'Root Mean Square Layer Normalization'. The insight: LayerNorm's mean-centering step is not necessary.
RMSNorm:
rms(x) = sqrt(mean(x²))    # only the root mean square
x_norm = x / rms(x)        # no mean subtraction
y = γ × x_norm             # only gain, no bias
No mean computation, no bias parameter.
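A tiny worked example (with gain γ=1) that shows the difference from LayerNorm: RMSNorm rescales the vector but does not remove its mean.

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])

rms = torch.sqrt(x.pow(2).mean())                # sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.74
print(x / rms)                                   # RMSNorm: the mean stays positive
print((x - x.mean()) / x.std(unbiased=False))    # LayerNorm core: zero mean, unit variance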

5.2 Computational comparison#

LayerNorm per token:
1. Compute mean (sum + divide)
2. Subtract mean (broadcast)
3. Compute variance (square, sum, divide)
4. Compute std (sqrt)
5. Divide
6. Multiply by gain
7. Add bias
RMSNorm per token:
1. Compute mean of squares (square, sum, divide)
2. Compute rms (sqrt)
3. Divide
4. Multiply by gain
RMSNorm: 4 vs 7 ops. 30% faster in practice.

5.3 Quality preservation#

Zhang 2019's empirical result: RMSNorm reaches LayerNorm-comparable quality in most settings. Mean centering is sometimes critical, sometimes not.

5.4 Adoption#

  • T5: RMSNorm
  • Llama-1, Llama-2, Llama-3, Mistral, Mixtral, Qwen: RMSNorm
  • GPT-4 (presumed): LayerNorm or RMSNorm (unclear)
  • BERT, GPT-2, GPT-3: LayerNorm (legacy)
Modern trend: RMSNorm dominance.

5.5 PyTorch implementation#

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Normalize along the last dim
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gain * x / rms
Only a few lines. For production use, optimize with a Triton/fused kernel.
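If you are on a recent PyTorch (2.4 or newer, to my knowledge), there is also a built-in nn.RMSNorm. A hedged sketch checking it against the formula above; with the gain initialized to ones, the outputs should agree:

import torch
import torch.nn as nn

x = torch.randn(2, 8, 4096)
rmsnorm = nn.RMSNorm(4096, eps=1e-6)             # built-in; weight initialized to ones

manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
print(torch.allclose(rmsnorm(x), manual, atol=1e-5))  # expected: True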

5.6 Why Llama-3 uses RMSNorm#

  • ~30% faster forward pass
  • Only a minor quality difference
  • Modern best practice
  • Mean centering is unnecessary in most cases
  • Parameter savings: the bias vector is dropped from every norm layer

7-10. Pre-LN vs Post-LN#

7.1 Original Vaswani 2017 (Post-LN)#

In the original transformer paper:
# Post-LN architecture
x = LayerNorm(x + Attention(x))   # norm applied to residual + sublayer output
x = LayerNorm(x + FFN(x))
The norm sits at the output of each sublayer, after the residual add (post = after).
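A minimal PyTorch sketch of this ordering (the class name PostLNBlock is mine, and nn.MultiheadAttention stands in for the attention sublayer), to contrast with the Pre-LN block shown later in this lesson:

import torch
import torch.nn as nn

class PostLNBlock(nn.Module):
    """Post-LN ordering: residual add first, then LayerNorm (Vaswani 2017 style)."""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        x = self.norm1(x + self.attn(x, x, x)[0])  # norm AFTER the residual add
        x = self.norm2(x + self.ffn(x))            # norm AFTER the residual add
        return x

# quick shape check
block = PostLNBlock(d_model=256, n_heads=4)
print(block(torch.randn(2, 16, 256)).shape)  # torch.Size([2, 16, 256])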

7.2 Pre-LN (modern)#

# Pre-LN architecture
x = x + Attention(LayerNorm(x))   # LN inside the residual branch, before the sublayer
x = x + FFN(LayerNorm(x))         # LN inside the residual branch, before the sublayer
The norm sits at the input of each sublayer (pre = before).

7.3 Why Pre-LN is the modern standard#

Xiong et al. 2020: 'On Layer Normalization in the Transformer Architecture'.
The Post-LN problems:
  • Gradient magnitudes blow up in the deeper layers
  • Learning rate warmup is mandatory (otherwise training hits NaN)
  • Training is not stable for transformers deeper than about 12 layers
The Pre-LN advantages:
  • Gradient magnitudes stay stable across layers
  • No warmup needed (or only a minimal one)
  • Transformers with 100+ layers become feasible

7.4 Math: gradient flow#

In Pre-LN the residual connection gives a direct path:
x_L = x_0 + Σ_i sublayer_i(LN(x_i))
The gradient has a direct path: ∂x_L / ∂x_0 ≈ I (identity). Even in a deep network the gradient does not vanish.
In Post-LN:
x_L = LN(x_{L-1} + sublayer(x_{L-1}))
Every LN rescales the gradient, and the effect compounds: in a deep stack the gradient reaching the early layers shrinks.
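A rough sketch of this effect with toy residual stacks (plain linear layers standing in for the attention/FFN sublayers). The exact numbers depend on seed, depth, and width, but the first-layer gradient of the post-norm stack typically comes out noticeably smaller:

import torch
import torch.nn as nn

def first_layer_grad_norm(depth=48, d=64, pre_norm=True, seed=0):
    torch.manual_seed(seed)
    layers = nn.ModuleList([nn.Linear(d, d) for _ in range(depth)])
    norms = nn.ModuleList([nn.LayerNorm(d) for _ in range(depth)])
    h = torch.randn(8, d)
    for lin, norm in zip(layers, norms):
        if pre_norm:
            h = h + lin(norm(h))   # Pre-LN: identity path around the sublayer
        else:
            h = norm(h + lin(h))   # Post-LN: every layer's norm sits on the main path
    h.pow(2).mean().backward()
    return layers[0].weight.grad.norm().item()  # gradient reaching the first block

print("Pre-LN :", first_layer_grad_norm(pre_norm=True))
print("Post-LN:", first_layer_grad_norm(pre_norm=False))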

7.5 Modern model preferences#

Pre-LN:
  • GPT-2 (later versions)
  • GPT-3+
  • Llama-1, Llama-2, Llama-3
  • Mistral
  • Mixtral
  • Claude (presumed)
  • BLOOM
  • T5 (pre-norm, with its simplified LayerNorm)
Post-LN:
  • Original BERT
  • Original Vaswani transformer

7.6 Practical impact#

Training with Pre-LN:
  • Learning rate: 1e-4 (no warmup needed)
  • Stable convergence
  • The 32-layer Llama-3 8B trains smoothly
Training with Post-LN:
  • Learning rate: 1e-5 (warmup mandatory, 1000+ steps; see the schedule sketch below)
  • Often NaN at scale
  • Training deep models is risky
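A minimal sketch of such a linear warmup schedule in PyTorch (the model and the values here are illustrative placeholders, not taken from any specific training recipe):

import torch
import torch.nn as nn

model = nn.Linear(10, 10)                           # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)

warmup_steps = 1000
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda step: min(1.0, (step + 1) / warmup_steps)  # linear ramp up to the base lr
)

for step in range(3):
    # ... forward/backward would go here in a real loop ...
    opt.step()
    sched.step()
    print(step, sched.get_last_lr())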

7.7 Llama-3 specific#

For each transformer block:
  h = h + Attention(RMSNorm(h))   # Pre-LN + RMSNorm
  h = h + FFN(RMSNorm(h))         # Pre-LN + RMSNorm
The RMSNorm + Pre-LN combination is the cornerstone of Llama-3's stability.
python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Llama-3 style RMSNorm."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Compute RMS along the last dim
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms


class LayerNorm(nn.Module):
    """Classic LayerNorm."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gain * (x - mean) / torch.sqrt(var + self.eps) + self.bias


# Pre-LN block (modern)
class PreLNBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.SiLU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        normed = self.norm1(x)
        h = x + self.attn(normed, normed, normed)[0]  # Pre-LN attention
        h = h + self.ffn(self.norm2(h))               # Pre-LN FFN
        return h


# Performance test: RMSNorm vs LayerNorm (requires a CUDA GPU)
import time

d_model = 4096
batch, seq = 32, 1024
x = torch.randn(batch, seq, d_model, device='cuda')

rms = RMSNorm(d_model).cuda()
ln = LayerNorm(d_model).cuda()

def bench(fn, n=100):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n):
        _ = fn(x)
    torch.cuda.synchronize()
    return (time.time() - start) / n * 1000  # ms per call

ln_ms = bench(ln)
rms_ms = bench(rms)
print(f"LayerNorm: {ln_ms:.2f} ms")
print(f"RMSNorm: {rms_ms:.2f} ms")
print(f"Speedup: {ln_ms / rms_ms:.2f}x")
RMSNorm + LayerNorm comparison — production benchmark
✅ Lesson 10.1 Summary: Normalization
Normalization is the cornerstone of transformer training stability. LayerNorm (Ba 2016): normalize over the feature dim, mean centering + variance scaling, learnable gain + bias. RMSNorm (Zhang 2019): NO mean centering, only the root mean square plus a gain; roughly 30% faster, with a small parameter saving. Llama-3's choice: RMSNorm. Pre-LN vs Post-LN: the original Vaswani Post-LN is unstable in deep transformers (exploding gradients). Pre-LN is the modern standard: stable gradient flow, no learning rate warmup required, 100+ layer training becomes feasible. Llama-3 architecture: RMSNorm + Pre-LN, the modern standard combo. In Lesson 10.2 we move on to the SwiGLU activation function.

Next Lesson: SwiGLU Activation Function#

Lesson 10.2: SiLU + GLU = SwiGLU (Shazeer 2020), why it replaces ReLU/GeLU, the Llama-3 implementation, FFN dimensions, and performance.

Frequently Asked Questions

Should you always prefer RMSNorm over LayerNorm?
Generally yes: modern models prefer RMSNorm. BUT: in small models or for specific tasks, LayerNorm's mean centering can still be an advantage. In production: Llama/Mistral use RMSNorm, BERT/GPT-2 use LayerNorm.
