
FSDP + ZeRO: Sharded Training — Memory Revolution from Rajbhandari 2020 to Llama-3

ZeRO (Zero Redundancy Optimizer, Rajbhandari et al. 2020) — the DeepSpeed library's sharding of optimizer state, gradients, and parameters across stages 1/2/3. FSDP (Fully Sharded Data Parallel) — PyTorch's native ZeRO-3 implementation. Llama-3 production: FSDP + activation checkpointing. Memory math: an 8B model becomes trainable on single-H100 GPUs.

Şükrü Yusuf KAYA
75 min read
Advanced
🧩 FSDP/ZeRO — made an 8B model trainable on a single H100
DDP's limit is memory: every GPU holds a full model copy → an 8B model needs 128 GB → a single H100 (80 GB) isn't enough. The solution: sharding. The Rajbhandari 2020 ZeRO paper shards optimizer state (stage 1), gradients (stage 2), and parameters (stage 3) across GPUs. An 8B model on 16 GPUs: 128/16 = 8 GB per GPU → Llama-3 8B comfortably fits on each H100. PyTorch FSDP and DeepSpeed ZeRO-3 are the de facto production standard. After 75 minutes you will have grasped the mathematical anatomy of the ZeRO stages, the FSDP wrapping mechanism, activation checkpointing, and the Llama-3 production setup.

Lesson Map (10 Sections)#

  1. DDP's memory bottleneck — gradient + optimizer state replication
  2. ZeRO stages 1, 2, 3 — incremental sharding
  3. ZeRO-1: optimizer state shard — 4x memory saving
  4. ZeRO-2: gradients shard — 8x memory saving
  5. ZeRO-3: parameters shard — full sharding
  6. FSDP — ZeRO-3 PyTorch native implementation
  7. All-gather + reduce-scatter — communication primitives
  8. Activation checkpointing — memory-compute trade-off
  9. Llama-3 production setup — FSDP + AC + bf16
  10. DeepSpeed vs FSDP — choice criteria

2-5. ZeRO Stages#

2.1 Memory breakdown#

AdamW training, K = world size (number of GPUs):
DDP per GPU memory:
  • Params: P (full copy)
  • Gradients: P (full copy)
  • Optimizer state (AdamW): 2P (m, v)
  • Total: 4P per GPU
For Llama-3-8B (P=8B):
  • DDP: 4 × 8B × 4 bytes = 128 GB per GPU — doesn't fit!

2.2 ZeRO Stage 1: Optimizer State Sharding#

AdamW's optimizer state (m, v) adds 2P values in total — these are sharded across GPUs.
Per GPU optimizer state: 2P/K
DDP: 4P → ZeRO-1: P + P + 2P/K
K=16: 128 GB → 32 + 32 + 4 = 68 GB per GPU. Fits on an 80 GB H100 — already trainable.

2.3 ZeRO Stage 2: Gradient Sharding#

Gradients are sharded as well:
Per GPU memory: P (params) + P/K (gradients) + 2P/K (optimizer)
K=16: 32 + 2 + 4 = 38 GB per GPU.

2.4 ZeRO Stage 3 = FSDP: Full Sharding#

Parameters are sharded too! Each GPU holds only 1/K of the params.
Per GPU memory: P/K (params) + P/K (gradients) + 2P/K (optimizer) = 4P/K
K=16, Llama-3-8B: 4 × 8B × 4 bytes / 16 = 8 GB per GPU.
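The per-GPU numbers in sections 2.1-2.4 can be reproduced with a small calculator (a sketch assuming fp32 for parameters, gradients, and AdamW state — 4 bytes each, matching the math above):

```python
def per_gpu_gb(p_billion, k, stage):
    """Per-GPU training memory in GB for a p_billion-param model on k GPUs.

    Assumes fp32 (4 bytes) for params, grads, and AdamW state (m, v).
    stage: 0 = DDP (no sharding), 1/2/3 = ZeRO stages.
    """
    gb = p_billion * 4                            # GB for one P-sized fp32 tensor
    params = gb if stage < 3 else gb / k          # sharded only in stage 3
    grads = gb if stage < 2 else gb / k           # sharded from stage 2
    optim = 2 * gb if stage < 1 else 2 * gb / k   # m, v sharded from stage 1
    return params + grads + optim

# Llama-3-8B, K=16
print(per_gpu_gb(8, 16, 0))  # DDP:    128.0 GB
print(per_gpu_gb(8, 16, 1))  # ZeRO-1:  68.0 GB
print(per_gpu_gb(8, 16, 2))  # ZeRO-2:  38.0 GB
print(per_gpu_gb(8, 16, 3))  # ZeRO-3:   8.0 GB
```

As K grows, ZeRO-3 memory shrinks toward zero (4P/K), which is exactly why the 8B model fits per GPU.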

2.5 ZeRO-3 communication#

Forward pass: each layer needs its full params → all-gather them from all GPUs just before use. Backward: gradients → reduce-scatter (each GPU keeps only the shard it owns).
More communication than DDP (the extra all-gathers). The trade-off: memory for communication.
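To see what these two primitives do to the data, here is a pure-Python simulation of K ranks (a sketch of the semantics only, not of the `torch.distributed` API):

```python
K = 4  # simulated world size

def all_gather(shards):
    """Every rank ends up with the concatenation of all param shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in range(K)]

def reduce_scatter(per_rank_grads):
    """Sum gradients elementwise across ranks, then give rank r only chunk r."""
    n = len(per_rank_grads[0])
    summed = [sum(g[i] for g in per_rank_grads) for i in range(n)]
    chunk = n // K
    return [summed[r * chunk:(r + 1) * chunk] for r in range(K)]

# Forward: each rank owns a 2-element param shard; all-gather the full 8 params.
shards = [[r * 2, r * 2 + 1] for r in range(K)]
print(all_gather(shards)[0])      # [0, 1, 2, 3, 4, 5, 6, 7] on every rank

# Backward: every rank computed a full 8-element gradient; reduce-scatter
# returns only the shard each rank owns, already summed across ranks.
grads = [[1] * 8 for _ in range(K)]
print(reduce_scatter(grads))      # [[4, 4], [4, 4], [4, 4], [4, 4]]
```

Note the asymmetry: all-gather moves parameters out to everyone temporarily, while reduce-scatter leaves each rank holding only the gradient shard it will use for its optimizer-state shard.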

2.6 Activation memory#

Full sharding doesn't touch activation memory. For Llama-3-8B, the hidden state alone is 2048 seq × 4096 dim × 2 bytes ≈ 16 MB per layer per sequence; across 32 layers, a realistic batch, and the attention/MLP intermediates, activations reach tens of GB.
Solution: activation checkpointing — save activations only at block boundaries and recompute the rest during the backward pass.
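A toy count makes the trade-off concrete (illustrative numbers only — `acts_per_layer` is a stand-in for the attention/MLP intermediates a real block would keep for backward):

```python
def activation_memory(n_layers, acts_per_layer, checkpoint):
    """Toy activation accounting: each layer produces acts_per_layer tensors.

    checkpoint=False: every intermediate tensor is kept for the backward pass.
    checkpoint=True : only each block's input is kept; internals are
                      recomputed during backward (~one extra forward of work).
    Returns (tensors stored, extra forward passes of compute).
    """
    if checkpoint:
        stored = n_layers              # one saved input per wrapped block
        extra_forward = 1.0            # recompute internals in backward
    else:
        stored = n_layers * acts_per_layer
        extra_forward = 0.0
    return stored, extra_forward

# 32 layers, ~4 large intermediates per layer
print(activation_memory(32, 4, checkpoint=False))  # (128, 0.0)
print(activation_memory(32, 4, checkpoint=True))   # (32, 1.0) -> 4x less memory
```

The compute cost of the extra forward is why checkpointed steps run roughly a third slower, while activation memory drops by the per-block intermediate factor.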

2.7 Hybrid: ZeRO + DDP#

Real-world setups often combine both: replicate across nodes (outer, DDP-style) and shard within each node (inner, ZeRO-style) — the best of both.

6-9. FSDP PyTorch Production#

6.1 FSDP PyTorch native#

FSDP ships natively with PyTorch 1.11+ as a ZeRO-3 implementation. Minimal usage:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = MyModel()
model = FSDP(model)
```

Parameters, gradients, and optimizer state are sharded automatically.

6.2 Auto-wrap policy#

FSDP supports nested wrapping, so each transformer block becomes its own shard unit:

```python
import functools

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaBlock},  # wrap each block separately
)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # ZeRO-3
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # overlap comm with compute
)
```

6.3 Sharding strategies#

  • FULL_SHARD: ZeRO-3 (params + grads + optim state)
  • SHARD_GRAD_OP: ZeRO-2 (grads + optim state)
  • NO_SHARD: DDP equivalent
  • HYBRID_SHARD: shard within node, replicate across nodes

6.4 Activation checkpointing#

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=lambda submodule: checkpoint_wrapper(
        submodule, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda submodule: isinstance(submodule, LlamaBlock),
)
```

Each block recomputes its activations during backward → roughly 30% slower forward+backward, but 50%+ activation memory savings.

6.5 Llama-3-8B FSDP setup (production)#

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Mixed precision policy: everything in bf16
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# FSDP wrap
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=None,                                 # keep everything on GPU
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    use_orig_params=True,                             # PyTorch 2.0+ optimization
)

# Activation checkpointing
apply_activation_checkpointing(model, check_fn=lambda m: isinstance(m, LlamaBlock))
```

6.6 Memory profile#

Llama-3-8B on H100s (80 GB), one node of K=8 GPUs, everything in bf16 as in the policy above:
  • Params (FSDP shard): 8B × 2 bytes / 8 = 2 GB
  • Gradients (FSDP shard): 2 GB
  • Optimizer state (bf16 m, v, sharded): 4 GB
  • Activations (AC enabled): ~5 GB
  • Buffers + overhead: ~1 GB
  • Total: ~14 GB per GPU
The remaining ~65 GB goes to a larger batch size → faster training.
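The profile can be sanity-checked with arithmetic (a sketch assuming bf16 — 2 bytes — for params, grads, and AdamW state; the activation and overhead figures are illustrative, and real profiles vary with batch size and allocator behavior):

```python
P = 8  # Llama-3-8B: params in billions; 1B bf16 params = 2 GB
K = 8  # one node, 8 GPUs

params = 2 * P / K      # bf16 param shard      -> 2.0 GB
grads = 2 * P / K       # bf16 gradient shard   -> 2.0 GB
optim = 2 * 2 * P / K   # bf16 AdamW m and v    -> 4.0 GB
acts = 5                # activations with checkpointing (illustrative)
overhead = 1            # buffers, fragmentation, NCCL (illustrative)

total = params + grads + optim + acts + overhead
print(total)  # 14.0 GB per GPU
```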

6.7 DeepSpeed alternative#

The DeepSpeed library (Microsoft) is the original ZeRO implementation; FSDP is now roughly comparable. DeepSpeed's advantages: more features (ZeRO-Infinity offload to CPU/NVMe), config-file-driven setup, mature ecosystem.
PyTorch FSDP's advantages: native, simpler integration, tracks PyTorch versioning. The modern default: FSDP, unless a specific DeepSpeed feature is needed.
✅ Lesson 13.2 Summary — FSDP + ZeRO
ZeRO (Rajbhandari 2020): sharding stages 1 (optimizer state), 2 (+ gradients), 3 (+ parameters). FSDP = PyTorch-native ZeRO-3. Memory math: with K=16 GPUs, per-GPU state drops from 4P to 4P/K. Llama-3-8B with FSDP + activation checkpointing: ~14 GB per H100 — a single 8-GPU node is enough. Hybrid: shard within a node, replicate across nodes ('HYBRID_SHARD'). DeepSpeed vs FSDP: FSDP is the modern preference; DeepSpeed for advanced features. In Lesson 13.3 we move on to 3D parallelism (TP + PP + DP) and Llama-3-70B+ training.

Next Lesson: 3D Parallelism#

Lesson 13.3: Tensor Parallelism (TP) + Pipeline Parallelism (PP) + Data Parallelism (DP) — frontier model training (70B, 405B). Megatron-LM.

Frequently Asked Questions

How much communication overhead does FSDP add over DDP?
The forward all-gathers and backward reduce-scatters add bandwidth pressure. With NVLink → 10-20% overhead; with InfiniBand across nodes → 20-40%. The memory savings are worth the speed cost (an 8B model can't be trained on a single GPU without sharding).

