PyTorch FSDP Anatomy: FULL_SHARD vs SHARD_GRAD_OP vs HYBRID_SHARD + Mixed Precision Policy

FSDP is modern PyTorch's weapon of choice for distributed training. This lesson covers the 3 sharding strategies (FULL_SHARD: params + grads + optimizer state sharded; SHARD_GRAD_OP: only grads + optimizer state; HYBRID_SHARD: intra-node FSDP + inter-node DDP), the MixedPrecision policy (param/reduce/buffer dtypes), BackwardPrefetch, and auto_wrap_policy (transformer layer-wise), plus a complete Llama 3.3 70B QLoRA recipe on 8×H100 SXM.

Şükrü Yusuf KAYA
38 minute read
Advanced
🎯 Why FSDP?
A 70B+ model does not fit on an RTX 4090 (24 GB). Even a single 80 GB H100 cannot train it (W+G+O alone comes to 280 GB with AdamW). The solution is sharding: split every parameter, gradient, and optimizer state across N GPUs. FSDP is PyTorch's native and most mature sharding tool.

1. The 3 Sharding Strategies

FULL_SHARD (ZeRO-3 equivalent):
  • Each GPU holds only W/N + G/N + O/N
  • Forward: all-gather W, then free it
  • Backward: all-gather W again, reduce-scatter the gradients
SHARD_GRAD_OP (ZeRO-2 equivalent):
  • Each GPU holds the full W (replicated); G and O are sharded
  • Memory: W + G/N + O/N
HYBRID_SHARD:
  • Intra-node: FULL_SHARD (e.g. 8 GPUs over NVLink)
  • Inter-node: DDP (e.g. across 4 nodes)
  • Best of both: exploit NVLink speed, save inter-node bandwidth
| Strategy | Memory | Communication | Use case |
| --- | --- | --- | --- |
| DDP (baseline) | W + G + O | All-reduce(G) only | small models, many GPUs |
| SHARD_GRAD_OP | W + G/N + O/N | Reduce-scatter(G); no param all-gather needed | mid-size models |
| FULL_SHARD | W/N + G/N + O/N | All-gather(W) + reduce-scatter(G) | large models |
| HYBRID_SHARD | W/N_node + G/N_node + O/N_node | All-gather over NVLink, all-reduce across nodes | multi-node |
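The full example below uses FULL_SHARD; for HYBRID_SHARD, recent PyTorch (2.2+) takes a 2D device mesh. A minimal sketch, assuming 4 nodes with 8 GPUs each (`model` is the unwrapped module):

python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Outer mesh dim replicates across nodes (DDP-style); inner dim shards within a node.
mesh = init_device_mesh("cuda", (4, 8))  # (num_nodes, gpus_per_node)
model = FSDP(
    model,
    device_mesh=mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)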
Llama 3.3 70B with AdamW, compared across 8 GPUs:
  • DDP: 280 GB / GPU → does not fit even on 8×H100 80GB
  • SHARD_GRAD_OP: 140 GB / GPU → still does not fit
  • FULL_SHARD: 35 GB / GPU → fits comfortably
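As a back-of-envelope check on these numbers, a minimal sketch: the ~280 GB W+G+O total comes from the comparison above, the even split between weights and grads + optimizer state is an assumption, and activations are ignored, so treat the outputs as rough.

python
def per_gpu_gb(w_gb: float = 140.0, go_gb: float = 140.0, n: int = 8) -> dict:
    """Per-GPU memory: w_gb = weights, go_gb = grads + optimizer state, n = GPUs."""
    return {
        "DDP": w_gb + go_gb,                # everything replicated: 280 GB
        "SHARD_GRAD_OP": w_gb + go_gb / n,  # full W, sharded G+O: ~157 GB (roughly
                                            # the "still does not fit" figure above)
        "FULL_SHARD": (w_gb + go_gb) / n,   # everything sharded: 35 GB
    }

print(per_gpu_gb())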
python
# === FSDP + Llama 3.3 70B Full FT (8×H100 SXM) ===
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    BackwardPrefetch,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# 1. Init distributed
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 2. Load model (each GPU first loads the full model; FSDP shards it afterwards)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# 3. Mixed precision policy
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # params computed in bf16
    reduce_dtype=torch.float32,   # gradients reduced in fp32 (numerical stability)
    buffer_dtype=torch.bfloat16,  # buffers in bf16
)

# 4. Auto-wrap policy — transformer layer-wise
#    (transformer_auto_wrap_policy must be partially applied; FSDP calls it
#     per module with extra arguments)
auto_wrap = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# 5. FSDP wrap
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=bf16_policy,
    auto_wrap_policy=auto_wrap,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch next layer's params
    use_orig_params=True,                             # recommended on PyTorch 2.0+
    device_id=torch.cuda.current_device(),
    sync_module_states=True,                          # sync init across ranks
    cpu_offload=CPUOffload(offload_params=False),     # disabled for speed
)

# 6. Optimizer — construct it AFTER the FSDP wrap so it sees the sharded params
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-6,  # LOW lr for full FT (vs. ~2e-4 typical for LoRA)
    weight_decay=0.01,
)

# 7. Training loop (batches must include labels so the model returns a loss)
for batch in loader:
    batch = {k: v.cuda() for k, v in batch.items()}
    out = model(**batch)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# 8. Save FSDP state — consolidate a full state_dict on rank 0
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state = model.state_dict()  # collective: ALL ranks must call this, or it hangs
if dist.get_rank() == 0:
    torch.save(state, "llama-3.3-70b-tr-finetuned.pt")
FSDP + Llama 3.3 70B Full FT — 8×H100 SXM
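A note on launching: the script assumes a torchrun-style launcher, which sets the RANK / LOCAL_RANK / WORLD_SIZE environment variables the code reads, e.g. `torchrun --nproc_per_node=8 train_fsdp.py` on a single node (the file name train_fsdp.py is illustrative).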

2. FSDP Tuning Tips

| Parameter | Recommended | Effect |
| --- | --- | --- |
| sharding_strategy | FULL_SHARD (70B+) / HYBRID_SHARD (multi-node) | memory vs. communication trade-off |
| mixed_precision.reduce_dtype | fp32 | numerical stability (bf16 reduce risks loss spikes) |
| backward_prefetch | BACKWARD_PRE | prefetches the next layer's params; +5-10% throughput |
| use_orig_params | True | LoRA + FSDP compatibility |
| auto_wrap_policy | transformer_auto_wrap_policy | layer-wise sharding (far easier than manual wrapping) |
| forward_prefetch | True (PyTorch 2.4+) | early all-gather of the next params in forward |
| limit_all_gathers | True | caps the memory peak |
| sync_module_states | True (at init) | broadcasts weights from rank 0 to the other ranks |

Of these, forward_prefetch and limit_all_gathers are the only ones not shown in the section-1 example; see the sketch below.
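A minimal sketch of how the two remaining knobs slot into the same FSDP call, assuming bf16_policy and auto_wrap as defined in section 1:

python
# Sketch: section-1 FSDP call extended with the remaining table knobs.
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=bf16_policy,                      # from section 1
    auto_wrap_policy=auto_wrap,                       # from section 1
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    forward_prefetch=True,    # all-gather the next params early during forward
    limit_all_gathers=True,   # rate-limit all-gathers to cap the memory peak
    use_orig_params=True,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
)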
8×H100 SXM + Llama 3.3 70B FSDP bench:
  • Full FT: 7.2 h per epoch (50K instructions)
  • QLoRA: 5.6 h per epoch (NF4 + FSDP)
  • Steps/s: 1.3 (full), 1.8 (QLoRA)
  • Peak memory/GPU: 42 GB (full FT), 28 GB (QLoRA)
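The QLoRA row assumes the bitsandbytes NF4 path. A minimal sketch of a quantization + LoRA setup that stays FSDP-compatible; the hyperparameters are illustrative, not the benchmarked recipe, and the key detail is bnb_4bit_quant_storage, which keeps the quantized weights in a dtype FSDP can flatten and shard:

python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # lets FSDP shard the 4-bit weights
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
)
lora = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,  # illustrative values
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)
# ...then wrap with FSDP as in section 1.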
🐛 FMD — 'FSDP run hangs at first step'
Hypotheses: (a) `sync_module_states=False` → each rank has different weights → the first reduce hangs. Fix: set it to True. (b) `use_orig_params=False` + PEFT (LoRA) → parameter-mapping confusion → deadlock. Fix: set it to True. (c) NCCL collective timeout (30 min by default) → pass a longer `timeout` to `init_process_group`; see the sketch below. (d) Wrong auto-wrap policy (e.g. the whole model as a single FSDP unit) → OOM in forward. Fix: transformer_auto_wrap_policy. Drill: run with NCCL_DEBUG=INFO and locate the hang point.
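A minimal sketch for hypothesis (c) and the drill. NCCL_DEBUG is a standard NCCL environment variable; the collective timeout is raised through init_process_group rather than an environment variable:

python
import os
from datetime import timedelta
import torch.distributed as dist

os.environ["NCCL_DEBUG"] = "INFO"  # verbose NCCL logs: find which collective hangs
dist.init_process_group(
    backend="nccl",
    timeout=timedelta(hours=1),    # raise the default collective timeout
)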
✅ Deliverables
  1. Run an FSDP toy job (a 200M model, 100 steps) on 2 local GPUs (if available) or a cloud 8×H100 dev cluster.
  2. Measure the memory peak on every rank.
  3. Next lesson: 4.2 — FSDP2 Per-Parameter Sharding.
