DPO Implementation From Scratch: One Page of Code, No TRL Source Required
Write your own DPO loss without TRL's DPOTrainer: log-probability computation, reference-model handling, the loss formula, and gradient backprop — in roughly 80 lines of PyTorch. The point is that when something breaks, you know exactly where. This is the Cookbook's deep-dive implementation lesson.
Şükrü Yusuf KAYA
30 minute read
Advanced · Python
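For reference, this is the sigmoid-form DPO objective (Rafailov et al., 2023) that the code below implements, where $y_w$ is the chosen response and $y_l$ the rejected one:

$$
\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[\log\sigma\!\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} \;-\; \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right]
$$

The four log terms are exactly the four `compute_logps` calls in the training loop; `dpo_loss` applies $\beta$, the subtraction, and $-\log\sigma(\cdot)$.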
```python
# === DPO from scratch — no TRL ===
import torch
import torch.nn.functional as F


def compute_logps(model, tokenizer, prompts, responses, device="cuda"):
    """Token-level log-probabilities of responses given prompts.

    Loops example-by-example for clarity; batch-pack for speed in real runs.
    """
    all_logps = []
    for prompt, response in zip(prompts, responses):
        full_text = prompt + response
        prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        full_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(device)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(full_ids).logits  # [1, seq, vocab]

        # Shift for next-token prediction
        logits = logits[:, :-1, :]  # position t predicts token t+1
        labels = full_ids[:, 1:]

        # Mask: only score response tokens (everything after the prompt)
        prompt_len = prompt_ids.size(1)
        mask = torch.zeros_like(labels, dtype=torch.bool)
        mask[:, prompt_len - 1:] = True  # response starts here in the shifted labels

        # Log-prob of each label token
        log_probs = F.log_softmax(logits, dim=-1)
        token_logps = torch.gather(log_probs, dim=-1,
                                   index=labels.unsqueeze(-1)).squeeze(-1)
        token_logps = token_logps * mask  # zero out prompt tokens

        # Sum over response tokens -> sequence log-prob
        seq_logp = token_logps.sum(dim=-1)  # [1]
        all_logps.append(seq_logp)
    return torch.cat(all_logps)  # [batch]


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """DPO loss — sigmoid form."""
    # Log ratios
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps

    # DPO logits: how much more the policy prefers chosen over rejected,
    # relative to the reference model
    logits = beta * (pi_logratios - ref_logratios)

    # Negative log-sigmoid
    loss = -F.logsigmoid(logits).mean()

    # Implicit reward proxies, for logging only
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    reward_acc = (chosen_rewards > rejected_rewards).float().mean()

    return loss, {"loss": loss.item(),
                  "reward_acc": reward_acc.item(),
                  "chosen_reward": chosen_rewards.mean().item(),
                  "rejected_reward": rejected_rewards.mean().item()}


# === Training loop ===
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False

for batch in loader:
    prompts, chosens, rejecteds = batch["prompt"], batch["chosen"], batch["rejected"]

    # Policy log-probs (gradients flow through these)
    policy_c_logps = compute_logps(policy_model, tok, prompts, chosens)
    policy_r_logps = compute_logps(policy_model, tok, prompts, rejecteds)

    # Reference log-probs (no grad — the reference stays frozen)
    with torch.no_grad():
        ref_c_logps = compute_logps(ref_model, tok, prompts, chosens)
        ref_r_logps = compute_logps(ref_model, tok, prompts, rejecteds)

    # Loss and step
    loss, metrics = dpo_loss(policy_c_logps, policy_r_logps,
                             ref_c_logps, ref_r_logps, beta=0.1)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"loss={metrics['loss']:.4f} rwd_acc={metrics['reward_acc']:.3f}")
```

DPO from scratch — ~80 lines of PyTorch
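The loop above assumes `policy_model`, `ref_model`, `tok`, `loader`, and `optimizer` already exist. A minimal wiring sketch follows — the model name, batch size, and learning rate are illustrative placeholders, not recommendations, and `dataset` is assumed to be any sequence of dicts with `"prompt"`, `"chosen"`, and `"rejected"` strings:

```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # placeholder; any causal LM works
tok = AutoTokenizer.from_pretrained(model_name)

# Policy gets trained; reference is a frozen copy of the same starting weights.
policy_model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
ref_model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

# DPO typically uses very small learning rates (~1e-7 to 5e-6)
optimizer = torch.optim.AdamW(policy_model.parameters(), lr=5e-7)

# Collate dicts of strings into {"prompt": [...], "chosen": [...], "rejected": [...]}
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    collate_fn=lambda b: {k: [ex[k] for ex in b] for k in b[0]})
```

Using the same starting checkpoint for both models matters: at step 0 the policy/reference log-ratios cancel, so the loss starts at $-\log\sigma(0) \approx 0.693$ — a useful sanity check.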
✅ Deliverables
1) Run the code above on a small model (e.g. SmolLM2-1.7B).
2) Compare against TRL's DPOTrainer — does it reach the same result? A comparison sketch follows below.
3) Next lesson: 11.4 — ORPO Single-Stage.
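For deliverable 2, here is a hedged sketch of the TRL side of the comparison. DPOTrainer's constructor arguments have shifted across TRL versions (e.g. `tokenizer` vs. `processing_class`), so verify the names against the version you install:

```python
# Sketch only — check argument names against your installed TRL version.
from trl import DPOConfig, DPOTrainer

config = DPOConfig(
    output_dir="dpo-check",
    beta=0.1,                       # same beta as the from-scratch loop
    per_device_train_batch_size=4,
    learning_rate=5e-7,
    max_steps=100,
)

trainer = DPOTrainer(
    model=policy_model,      # same starting checkpoint as above
    ref_model=ref_model,
    args=config,
    train_dataset=dataset,   # columns: "prompt", "chosen", "rejected"
    processing_class=tok,
)
trainer.train()
# Compare TRL's logged loss and reward accuracy against your from-scratch metrics.
```

Don't expect bit-for-bit identical numbers: TRL tokenizes prompt and completion separately and applies its own truncation and padding. The loss curves and reward accuracy should nevertheless track each other closely; a large gap usually points at a masking or shifting bug in the from-scratch version.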