PyTorch FSDP Anatomy: FULL_SHARD vs SHARD_GRAD_OP vs HYBRID_SHARD + Mixed Precision Policy
FSDP, modern PyTorch's weapon of choice for distributed training: the 3 sharding strategies, the MixedPrecision policy, BackwardPrefetch, and auto_wrap_policy, plus a Llama 3.3 70B QLoRA recipe on 8×H100 SXM.
Şükrü Yusuf KAYA
38 min read
Advanced

🎯 Why FSDP?
A 70B+ model does not fit on an RTX 4090 (24 GB). Even a single H100 80GB cannot train it (with AdamW, W+G+O alone is 280 GB). The solution is sharding: split every parameter, gradient, and optimizer state across N GPUs. FSDP is PyTorch's native and most mature sharding tool.
1. The 3 Sharding Strategies#
FULL_SHARD (ZeRO-3 equivalent):
- Each GPU holds only W/N + G/N + O/N
- Forward: all-gather W, then free it
- Backward: all-gather W again, reduce-scatter the gradients

SHARD_GRAD_OP (ZeRO-2 equivalent):
- Each GPU holds the full W (replicated); G and O are sharded
- Memory: W + G/N + O/N

HYBRID_SHARD (see the sketch after the table):
- Intra-node: FULL_SHARD (e.g. 8 GPUs over NVLink)
- Inter-node: DDP-style replication (e.g. across 4 nodes)
- Best of both: NVLink speed inside the node, less cross-node bandwidth
| Strategy | Memory | Communication | Use case |
|---|---|---|---|
| DDP (baseline) | W + G + O | All-reduce(G) only | small model, many GPUs |
| SHARD_GRAD_OP | W + G/N + O/N | Reduce-scatter(G); no second W all-gather in backward | mid-size |
| FULL_SHARD | W/N + G/N + O/N | All-gather(W) + reduce-scatter(G) | large model |
| HYBRID_SHARD | W/N_local + G/N + O/N | intra-node all-gather(W), inter-node all-reduce(G) | multi-node |
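For multi-node runs, HYBRID_SHARD can be expressed as a 2D device mesh. A minimal sketch, assuming PyTorch >= 2.2 and a 4-node × 8-GPU cluster (the mesh shape and dim names are assumptions; `model` is built as in the recipe below):

```python
# HYBRID_SHARD via a 2D device mesh: shard inside a node, replicate across nodes.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# dim 0 = replicate across 4 nodes, dim 1 = shard across 8 intra-node GPUs
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))
model = FSDP(
    model,
    device_mesh=mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```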
Llama 3.3 70B AdamW comparison (8 GPUs):
- DDP: 280 GB / GPU → 8×H100 80GB is not enough
- SHARD_GRAD_OP: ≥140 GB / GPU (the replicated W alone is 140 GB) → still not enough
- FULL_SHARD: 35 GB / GPU → fits comfortably ✅
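A tiny sketch of the arithmetic behind these numbers, assuming the section's simplified accounting (280 GB of total model state for 70B params, of which the replicated bf16 weights are 140 GB; activations and buffers not counted):

```python
# Back-of-envelope per-GPU model-state memory (a sketch; numbers follow the
# section's simplified accounting, activations not counted).
def per_gpu_gb(total_gb: float = 280.0, weights_gb: float = 140.0, n_gpus: int = 8) -> dict:
    rest = total_gb - weights_gb  # gradients + optimizer state
    return {
        "DDP": total_gb,                              # everything replicated
        "SHARD_GRAD_OP": weights_gb + rest / n_gpus,  # full W, sharded G+O
        "FULL_SHARD": total_gb / n_gpus,              # everything sharded
    }

print(per_gpu_gb())
# {'DDP': 280.0, 'SHARD_GRAD_OP': 157.5, 'FULL_SHARD': 35.0}
```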
```python
# === FSDP + Llama 3.3 70B Full FT (8×H100 SXM) ===
import os
import functools

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    BackwardPrefetch,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# 1. Init distributed
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 2. Load model on each GPU (each GPU first loads the full model, then FSDP shards it)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# 3. Mixed precision policy
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # params computed in bf16
    reduce_dtype=torch.float32,   # gradients reduced in fp32 (numerical stability)
    buffer_dtype=torch.bfloat16,  # buffers in bf16
)

# 4. Auto-wrap policy — one FSDP unit per transformer layer
auto_wrap = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# 5. FSDP wrap
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=bf16_policy,
    auto_wrap_policy=auto_wrap,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch next unit's params
    use_orig_params=True,                             # PyTorch 2.0+ recommended
    device_id=torch.cuda.current_device(),
    sync_module_states=True,                          # sync init across ranks
    cpu_offload=CPUOffload(offload_params=False),     # disabled for speed
)

# 6. Optimizer — FSDP-aware (construct after wrapping)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-6,  # LOW lr for full FT (vs. ~2e-4 for LoRA)
    weight_decay=0.01,
)

# 7. Training loop
for batch in loader:
    out = model(**batch)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# 8. Save FSDP state — consolidate a full state_dict on rank 0
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    if dist.get_rank() == 0:
        torch.save(model.state_dict(), "llama-3.3-70b-tr-finetuned.pt")
```
FSDP + Llama 3.3 70B Full FT — 8×H100 SXM
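To launch the recipe, the usual entry point is torchrun, e.g. `torchrun --nproc_per_node=8 train_fsdp.py` (the script name is a placeholder): torchrun sets the LOCAL_RANK environment variable that the init code reads. Note that `loader` above is assumed to be a distributed-aware DataLoader (e.g. one using DistributedSampler so each rank sees a different shard of the data).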
2. FSDP Tuning Tips#
| Parameter | Recommended | Effect |
|---|---|---|
| sharding_strategy | FULL_SHARD (70B+) / HYBRID_SHARD (multi-node) | memory vs. communication trade-off |
| mixed_precision.reduce_dtype | fp32 | numerical stability (bf16 reduce risks loss spikes) |
| backward_prefetch | BACKWARD_PRE | prefetches the next layer's params, +5-10% throughput |
| use_orig_params | True | LoRA + FSDP compatibility |
| auto_wrap_policy | transformer_auto_wrap | layer-wise sharding (much easier than manual wrapping) |
| forward_prefetch | True (2.4+) | early all-gather of the next params in forward |
| limit_all_gathers | True | caps the memory peak |
| sync_module_states | True (at init) | broadcast from rank 0 to the other ranks |
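Wired into the constructor, the table's knobs look like this (a sketch reusing `model`, `bf16_policy`, and `auto_wrap` from the recipe above):

```python
# The table's tuning knobs applied in one FSDP constructor call.
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=bf16_policy,
    auto_wrap_policy=auto_wrap,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # overlap next all-gather with grad compute
    forward_prefetch=True,       # early all-gather of the next FSDP unit in forward
    limit_all_gathers=True,      # rate-limit all-gathers to cap the memory peak
    use_orig_params=True,        # keeps original param handles (LoRA/PEFT friendly)
    sync_module_states=True,     # broadcast rank-0 module states at init
    device_id=torch.cuda.current_device(),
)
```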
8×H100 SXM + Llama 3.3 70B FSDP bench:
- Full FT: 7.2 hours per epoch (50K instructions)
- QLoRA: 5.6 hours per epoch (NF4 + FSDP)
- Steps/s: 1.3 (full), 1.8 (QLoRA)
- Peak memory/GPU: 42 GB (full FT), 28 GB (QLoRA)
🐛 FMD — 'FSDP run hangs at first step'
Hypotheses: (a) `sync_module_states=False` → each rank starts from different weights → the first reduce hangs. Fix: set it to True. (b) `use_orig_params=False` + PEFT (LoRA) → parameter-mapping confusion → deadlock. Fix: set it to True. (c) NCCL collective timeout (30 min by default) → raise it via the `timeout` argument of `init_process_group` (see the sketch below). (d) Wrong auto-wrap policy (e.g. the whole model as a single FSDP unit) → OOM in forward. Fix: transformer_auto_wrap_policy. Drill: run with NCCL_DEBUG=INFO and locate the hang point.
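A sketch of the debug setup for the drill, combining verbose NCCL logging with a longer collective timeout (the one-hour value is an arbitrary example):

```python
# Debug setup for the hang drill (a sketch).
import os
from datetime import timedelta

import torch.distributed as dist

os.environ["NCCL_DEBUG"] = "INFO"  # verbose per-collective logs reveal the hang point
# Give slow first collectives more time before the watchdog kills the job
dist.init_process_group(backend="nccl", timeout=timedelta(hours=1))
```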
✅ Deliverables
1) FSDP toy run (200M model, 100 steps) on 2 local GPUs (if available) or on a cloud 8×H100 dev cluster; see the sketch below. 2) Measure the memory peak on every rank. 3) Next lesson: 4.2 — FSDP2 Per-Parameter Sharding.
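A compact sketch of the toy run in item 1 (the layer sizes are assumptions chosen to land near ~200M params; launch with torchrun as shown earlier):

```python
# Toy FSDP run (a sketch): ~200M-param encoder stack, 100 steps, per-rank peak memory.
import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# ~12.6M params per layer x 16 layers ≈ 200M
layer = nn.TransformerEncoderLayer(d_model=1024, nhead=16, dim_feedforward=4096, batch_first=True)
model = FSDP(
    nn.TransformerEncoder(layer, num_layers=16),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={nn.TransformerEncoderLayer}
    ),
    device_id=torch.cuda.current_device(),
)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(8, 128, 1024, device="cuda")  # dummy batch

for step in range(100):
    loss = model(x).pow(2).mean()  # dummy loss, enough to exercise fwd/bwd/step
    loss.backward()
    opt.step()
    opt.zero_grad()

# Deliverable 2: measure the peak on every rank
print(f"rank {dist.get_rank()}: peak {torch.cuda.max_memory_allocated() / 2**30:.1f} GiB")
dist.destroy_process_group()
```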