
FSDP2 (fully_shard): Per-Parameter Sharding + DTensor + 2024+ PyTorch Innovation

FSDP2 (PyTorch 2.4+) — evolution of FSDP. Per-parameter sharding (FlatParameter pattern dropped), DTensor backbone, FQN-based resumable checkpointing, easier mixed precision. Llama 3.3 70B + FSDP2 + DCP recipe.

Şükrü Yusuf KAYA
30 min read
Advanced

1. FSDP1 → FSDP2 Migration

| Aspect | FSDP1 (legacy) | FSDP2 |
| --- | --- | --- |
| Sharding unit | FlatParameter — all of a layer's params flattened into one buffer | DTensor — each param sharded separately |
| API | `FullyShardedDataParallel(model, ...)` | `fully_shard(model, ...)` |
| Mixed precision | manual policy | convenient per-module overrides |
| State dict | full/sharded/local — several modes | DCP (Distributed Checkpoint) standard |
| Resumable checkpoint | manual, painful | native via FQNs |
| `use_orig_params` | optional | always (default) |
| 2D parallelism (TP + FSDP) | extra effort | native |
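The sharding-unit difference in the table can be illustrated with a toy sketch on CPU. This is not FSDP code, just a plain-tensor simulation (assuming a hypothetical world size of 2) of how FSDP1 flattens a layer before splitting while FSDP2 splits each named parameter on its own:

```python
import torch
import torch.nn as nn

# Toy layer: weight is 6x4 (24 elements), bias is 6 (total 30).
layer = nn.Linear(4, 6)
world_size = 2  # hypothetical number of ranks

# FSDP1 (FlatParameter): all of the layer's params are concatenated into
# one flat buffer, then the buffer is chunked across ranks. Parameter
# boundaries and names are lost inside each shard.
flat = torch.cat([p.detach().flatten() for p in layer.parameters()])
flat_shards = flat.chunk(world_size)

# FSDP2 (per-parameter): each parameter is sharded on dim 0 by itself,
# so every shard still maps back to a named parameter.
per_param_shards = {
    name: p.detach().chunk(world_size, dim=0)
    for name, p in layer.named_parameters()
}

print(flat.numel())                      # 30
print([s.numel() for s in flat_shards])  # [15, 15]
print(list(per_param_shards))            # ['weight', 'bias']
```

Because each shard keeps its parameter identity, per-parameter operations like freezing or masking individual weights compose cleanly with sharding.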
Practical gains of FSDP2:
  1. DTensor: a PyTorch-style API that supports ordinary tensor manipulation
  2. Per-param: smoother LoRA + FSDP compatibility (gradient masking works without friction)
  3. DCP: multi-rank checkpoint save/load — atomic, parallel, FQN-mapped
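The FQN point deserves a concrete look. State-dict keys are fully qualified module paths, stable across runs, so a checkpoint can be loaded into a fresh replica regardless of initialization. A minimal single-process sketch (the `TinyBlock` module is hypothetical, for illustration only):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)
        self.mlp = nn.Linear(8, 8)

model = nn.Sequential(TinyBlock(), TinyBlock())
sd = model.state_dict()
# Keys are FQNs: module path + parameter name.
print(sorted(sd)[:2])  # ['0.attn.bias', '0.attn.weight']

# A freshly (differently) initialized replica resumes via those FQNs.
resumed = nn.Sequential(TinyBlock(), TinyBlock())
resumed.load_state_dict(sd)
assert all(torch.equal(sd[k], resumed.state_dict()[k]) for k in sd)
```

DCP builds on the same idea: each rank saves its DTensor shards keyed by FQN, and on load the keys are re-mapped onto whatever sharding layout the resuming job uses.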
```python
# === FSDP2 recipe ===
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM

# Distributed init
dist.init_process_group(backend="nccl")

# Model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    torch_dtype=torch.bfloat16,
)

# Per-module mixed precision
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
)

# Wrap each decoder layer separately
for layer in model.model.layers:
    fully_shard(layer, mp_policy=mp_policy)
# Wrap the root model as well
fully_shard(model, mp_policy=mp_policy)

# Optimizer must be built after wrapping, over the sharded params
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# === DCP checkpoint save ===
# get_state_dict returns (model_state_dict, optimizer_state_dict)
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id="ckpt/step-1000/")
# Each rank writes its own shards into ckpt/step-1000/ (parallel I/O)

# === Resume — restart with start_rank=4
model_sd, optim_sd = get_state_dict(model, optimizer)
state_dict = {"model": model_sd, "optim": optim_sd}
dcp.load(state_dict, checkpoint_id="ckpt/step-1000/")
set_state_dict(
    model, optimizer,
    model_state_dict=state_dict["model"],
    optim_state_dict=state_dict["optim"],
)
```
FSDP2 + DCP checkpoint
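The recipe assumes one process per GPU. A typical single-node launch (assuming the script above is saved under a hypothetical name like `train_fsdp2.py`) would be:

```shell
# Launch 8 processes on one node, one per GPU; torchrun sets the
# env vars that dist.init_process_group(backend="nccl") reads.
torchrun --nproc_per_node=8 train_fsdp2.py
```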
✅ Deliverables
  1. Experiment with FSDP2 on PyTorch ≥ 2.4.
  2. Compare FSDP1 vs FSDP2 throughput.
  3. Next lesson: 4.3 — DeepSpeed ZeRO Stage 1/2/3.
