Capstone Module 10: Llama-3 Transformer Block in 200 Lines from Scratch — RMSNorm + RoPE + GQA + SwiGLU

Module 10 capstone: implement a Llama-3-architecture transformer block in 200 lines. RMSNorm + Pre-LN + GQA (Grouped-Query Attention) + RoPE + SwiGLU FFN + residual connections. A synthesis of Modules 6-10. Turkish example forward pass, gradient-flow analysis, and a load test with actual Llama-3 weights.

Şükrü Yusuf KAYA
80 min read
Advanced
🎓 Capstone — build the Llama-3 transformer block with your own hands
Over five modules we covered tokenization, embedding, attention, position encoding, normalization, and activation, piece by piece. Now combine them: implement the Llama-3-architecture transformer block from scratch in about 200 lines. RMSNorm + Pre-LN + GQA + RoPE + SwiGLU + residual connections. The forward pass runs, real weights load, and the output is compared against actual Llama-3. This is the capstone of Module 10 and, at the same time, the synthesis of Modules 6-10: the closing of Part II. After these 80 minutes you will be able to assemble every component of a modern LLM architecture.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class LlamaConfig:
    d_model: int = 4096
    n_heads: int = 32
    n_kv_heads: int = 8        # GQA: every 4 query heads share one KV head
    d_head: int = 128
    d_ff: int = 14336          # Llama-3-8B intermediate size (3.5 × d_model)
    rope_base: float = 500000.0
    eps: float = 1e-6
    vocab_size: int = 128256
    max_seq_len: int = 8192


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # x / RMS(x), with a learned per-channel scale and no bias
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x * rms


def precompute_rope(d_head, max_seq_len, base):
    # Per-pair frequencies, then the [max_seq_len, d_head] angle tables
    inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    t = torch.arange(max_seq_len, dtype=torch.float)
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat([freqs, freqs], dim=-1)
    return emb.cos(), emb.sin()


def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rope(q, k, cos, sin):
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


class GQAttention(nn.Module):
    """Grouped-Query Attention with RoPE."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_rep = config.n_heads // config.n_kv_heads
        self.q_proj = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.n_kv_heads * config.d_head, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.n_kv_heads * config.d_head, bias=False)
        self.o_proj = nn.Linear(config.n_heads * config.d_head, config.d_model, bias=False)

    def forward(self, x, cos, sin):
        batch, seq, _ = x.shape
        # Project Q, K, V
        Q = self.q_proj(x).view(batch, seq, self.config.n_heads, self.config.d_head)
        K = self.k_proj(x).view(batch, seq, self.config.n_kv_heads, self.config.d_head)
        V = self.v_proj(x).view(batch, seq, self.config.n_kv_heads, self.config.d_head)
        # Transpose for attention
        Q = Q.transpose(1, 2)  # [B, H, S, D]
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        # RoPE: broadcast the [S, D] tables over batch and head dims
        cos_expand = cos[:seq].unsqueeze(0).unsqueeze(0)
        sin_expand = sin[:seq].unsqueeze(0).unsqueeze(0)
        Q, K = apply_rope(Q, K, cos_expand, sin_expand)
        # GQA: repeat K, V to match Q heads
        K = K.repeat_interleave(self.n_rep, dim=1)
        V = V.repeat_interleave(self.n_rep, dim=1)
        # Causal attention (SDPA dispatches to FlashAttention when available)
        out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
        # Concat heads
        out = out.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.o_proj(out)


class SwiGLUFFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False)

    def forward(self, x):
        # SwiGLU: down(silu(gate(x)) ⊙ up(x))
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class LlamaBlock(nn.Module):
    """Llama-3 transformer block: RMSNorm + Pre-LN + GQA + SwiGLU."""

    def __init__(self, config):
        super().__init__()
        self.attn_norm = RMSNorm(config.d_model, config.eps)
        self.attn = GQAttention(config)
        self.ffn_norm = RMSNorm(config.d_model, config.eps)
        self.ffn = SwiGLUFFN(config)

    def forward(self, x, cos, sin):
        h = x + self.attn(self.attn_norm(x), cos, sin)  # Pre-LN attention
        h = h + self.ffn(self.ffn_norm(h))              # Pre-LN FFN
        return h


# Test
config = LlamaConfig()
block = LlamaBlock(config).cuda().bfloat16()

cos, sin = precompute_rope(config.d_head, config.max_seq_len, config.rope_base)
cos = cos.cuda().bfloat16()
sin = sin.cuda().bfloat16()

x = torch.randn(1, 1024, config.d_model, dtype=torch.bfloat16, device='cuda')
out = block(x, cos, sin)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Params per block: {sum(p.numel() for p in block.parameters()):,}")  # ~218M

# Memory profile
torch.cuda.reset_peak_memory_stats()
_ = block(x, cos, sin)
peak = torch.cuda.max_memory_allocated() / 1e9
print(f"Peak memory: {peak:.2f} GB")
```
Llama-3 transformer block — production-grade in ~200 lines
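A quick sanity check on the ~218M parameter count printed above: the attention projections contribute 2 × 4096 × 4096 (q_proj, o_proj) plus 2 × 4096 × 1024 (k_proj, v_proj) ≈ 41.9M parameters, the SwiGLU FFN contributes 3 × 4096 × 14336 ≈ 176.2M, and the two RMSNorms add only 8,192, for a total of about 218.1M.

For the gradient-flow analysis, here is a minimal sketch of one way to do it (the small config values below are illustrative, not from the lesson): stack several blocks, backpropagate a dummy loss, and compare gradient norms across depth. With Pre-LN and residual connections, the norms should stay within the same order of magnitude from the top of the stack to the bottom.

```python
# Gradient-flow sketch with a deliberately small, hypothetical config so it
# runs on a modest GPU. Stack 8 blocks, backprop a dummy loss, and print
# per-block gradient norms; Pre-LN keeps them on a similar scale.
small = LlamaConfig(d_model=512, n_heads=8, n_kv_heads=2, d_head=64,
                    d_ff=1792, max_seq_len=512)
blocks = nn.ModuleList([LlamaBlock(small) for _ in range(8)]).cuda()
cos_s, sin_s = precompute_rope(small.d_head, small.max_seq_len, small.rope_base)
cos_s, sin_s = cos_s.cuda(), sin_s.cuda()

h = torch.randn(1, 256, small.d_model, device='cuda')
for blk in blocks:
    h = blk(h, cos_s, sin_s)
h.pow(2).mean().backward()  # dummy loss

for i, blk in enumerate(blocks):
    g = blk.attn.q_proj.weight.grad.norm().item()
    print(f"block {i}: ||grad q_proj|| = {g:.4e}")
```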
🎉 Module 10 + Part II Complete — Transformer Architecture
Across 3 lessons: LayerNorm/RMSNorm and Pre-LN/Post-LN (Llama-3 uses RMSNorm + Pre-LN), the SwiGLU activation (Shazeer 2020), and the Llama-3 transformer block capstone (a ~200-line implementation). Part II (The Transformer Skeleton) is complete: Module 6 (Tokenization) → 7 (Embedding) → 8 (Attention) → 9 (Position) → 10 (Block). You now have every component of the modern transformer architecture. Module 10 inventory: 3 lessons, 215 min. Overall curriculum: 11 modules, 71 lessons, ~64 hours. Next: Part III (Training & Scaling), covering the pre-training pipeline, optimizer dynamics, scaling laws, and distributed training.

Module 10 Inventory (Complete)

| # | Lesson | Duration |
|------|--------|----------|
| 10.1 | LayerNorm + RMSNorm + Pre-LN/Post-LN | 70 min |
| 10.2 | SwiGLU Activation (Shazeer 2020) | 65 min |
| 10.3 | Capstone — Llama-3 Block in 200 Lines | 80 min |
| Total | 3 lessons | 215 min (~3.6 h) |

Part II (The Transformer Skeleton) Total

| Module | Lessons | Duration |
|--------|---------|----------|
| 6 — Tokenization | 10 | 660 min |
| 7 — Embedding | 6 | 415 min |
| 8 — Attention | 5 | 370 min |
| 9 — Position | 5 | 335 min |
| 10 — Block | 3 | 215 min |
| Part II Total | 29 | 1995 min (~33 h) |

Frequently Asked Questions

**Is this block mathematically identical to the real Llama-3 block?**

Yes: the computation is mathematically identical (modulo bf16 precision). With the same weights loaded, the block's output matches the corresponding Llama-3-8B layer output within the expected numerical tolerance; an exact bit-for-bit match is not guaranteed, since kernel and dtype differences introduce small rounding errors.
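The weight-load test itself can be sketched as follows. This is a minimal sketch, assuming you have accepted the gated meta-llama/Meta-Llama-3-8B license on Hugging Face and have roughly 16 GB of memory free; the module names (self_attn.q_proj, mlp.gate_proj, input_layernorm, ...) follow the transformers Llama implementation.

```python
# Hedged sketch: copy layer 0 of Llama-3-8B into our LlamaBlock, then compare
# its output against the reference model's hidden state after layer 0.
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Meta-Llama-3-8B"  # gated checkpoint; requires HF access
tok = AutoTokenizer.from_pretrained(name)
hf = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)
layer0 = hf.model.layers[0]

block = LlamaBlock(LlamaConfig()).bfloat16()
with torch.no_grad():
    block.attn_norm.weight.copy_(layer0.input_layernorm.weight)
    block.ffn_norm.weight.copy_(layer0.post_attention_layernorm.weight)
    for p in ("q_proj", "k_proj", "v_proj", "o_proj"):
        getattr(block.attn, p).weight.copy_(getattr(layer0.self_attn, p).weight)
    for p in ("gate_proj", "up_proj", "down_proj"):
        getattr(block.ffn, p).weight.copy_(getattr(layer0.mlp, p).weight)

# hidden_states[0] is the embedding output, hidden_states[1] the output of
# layer 0: feed the former through our block and compare with the latter.
ids = tok("Merhaba dünya", return_tensors="pt").input_ids
with torch.no_grad():
    hs = hf(ids, output_hidden_states=True).hidden_states
    cos, sin = precompute_rope(128, ids.shape[1], 500000.0)
    ours = block(hs[0], cos.bfloat16(), sin.bfloat16())
print((ours - hs[1]).abs().max())  # expect a small bf16-level difference
```

One source of residual difference worth knowing: the transformers RMSNorm upcasts to float32 internally, while the block above normalizes in bf16, so the mismatch stays at the bf16 rounding level rather than reaching exact zero.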

