FSDP2 (fully_shard): Per-Parameter Sharding + DTensor + 2024+ PyTorch Innovation
FSDP2 (PyTorch 2.4+) — the evolution of FSDP: per-parameter sharding (the FlatParameter pattern is dropped), a DTensor backbone, FQN-based resumable checkpointing, and simpler mixed precision. Includes a Llama 3.3 70B + FSDP2 + DCP recipe.
Şükrü Yusuf KAYA
30 min read
Advanced
1. FSDP1 → FSDP2 Transition
| Aspect | FSDP1 (legacy) | FSDP2 |
|---|---|---|
| Sharding unit | FlatParameter — all of a layer's params flattened into one tensor | DTensor — each param sharded individually |
| API | FullyShardedDataParallel(model, ...) | fully_shard(model, ...) |
| Mixed precision | manual policy | easy per-module overrides |
| State dict | full/sharded/local — several modes | DCP (Distributed Checkpoint) standard |
| Resumable checkpoint | manual, painful | native via FQNs |
| use_orig_params | optional | always (default) |
| 2D parallelism (TP + FSDP) | extra hard | native |
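The API row of the table can be sketched side by side. A minimal, illustrative sketch — `ToyModel` and the two helper functions are hypothetical stand-ins, not part of any library; only the `FSDP`/`fully_shard` calls come from PyTorch:

```python
import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for a transformer: a stack of blocks plus a head."""
    def __init__(self, dim=16, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.head = nn.Linear(dim, dim)

def shard_fsdp1(model: ToyModel):
    # Legacy API: a wrapper class; params of each wrapped unit are flattened
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    return FSDP(model, use_orig_params=True)

def shard_fsdp2(model: ToyModel):
    # FSDP2 API: fully_shard applied in place, innermost modules first;
    # every parameter stays a separate (D)Tensor
    from torch.distributed.fsdp import fully_shard
    for block in model.layers:
        fully_shard(block)
    fully_shard(model)  # root wrap last
    return model
```

Both helpers assume an already-initialized process group; the point is the shape of the call, not a runnable training loop.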
Practical gains of FSDP2:
- DTensor: PyTorch-style APIs for tensor manipulation
- Per-param sharding: smoother LoRA + FSDP compatibility (gradient masking just works)
- DCP: multi-rank checkpoint save/load — atomic, parallel, FQN-mapped
```python
# === FSDP2 recipe ===
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from transformers import AutoModelForCausalLM

# Distributed init
dist.init_process_group(backend="nccl")

# Model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    torch_dtype=torch.bfloat16,
)

# Per-module mixed precision
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
)

# Wrap each decoder layer separately
for layer in model.model.layers:
    fully_shard(layer, mp_policy=mp_policy)
# Wrap the root model as well
fully_shard(model, mp_policy=mp_policy)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# === DCP checkpoint save ===
# get_state_dict returns (model_state_dict, optim_state_dict)
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id="ckpt/step-1000/")
# Each rank writes its own shards into ckpt/step-1000/ (parallel I/O)

# === Resume — works even when restarting on a different number of ranks ===
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id="ckpt/step-1000/")
set_state_dict(model, optimizer,
               model_state_dict=model_sd, optim_state_dict=optim_sd)
```

FSDP2 + DCP checkpoint
✅ Deliverables
1. Experiment with FSDP2 on PyTorch ≥ 2.4.
2. Compare throughput of legacy FSDP1 vs FSDP2.
3. Next lesson: 4.3 — DeepSpeed ZeRO Stage 1/2/3.