
DPO Implementation From Scratch: A Single Page of Code Without TRL's Source Code

Write your own DPO loss function without using TRL's DPOTrainer: log-probability computation, reference-model handling, the loss formula, gradient backprop. About 80 lines of PyTorch, so that when you make a mistake you understand where it happened. The cookbook's deep-dive implementation lesson.

Şükrü Yusuf KAYA
30-minute read
Advanced
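The `dpo_loss` function below implements the sigmoid form of the DPO objective. For a preference pair $(x, y_w, y_l)$ with chosen response $y_w$ and rejected response $y_l$, the per-pair loss the code computes is

$$\mathcal{L}_{\text{DPO}} = -\log\sigma\Big(\beta\big[\log\pi_\theta(y_w\mid x) - \log\pi_{\text{ref}}(y_w\mid x)\big] - \beta\big[\log\pi_\theta(y_l\mid x) - \log\pi_{\text{ref}}(y_l\mid x)\big]\Big)$$

where $\pi_\theta$ is the policy being trained, $\pi_{\text{ref}}$ is the frozen reference model, and $\beta$ controls how strongly the policy is tied to the reference. Averaged over the batch, this is exactly what `dpo_loss` returns.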
python
# === DPO from scratch: without TRL ===
import torch
import torch.nn.functional as F

def compute_logps(model, tokenizer, prompts, responses, device="cuda"):
    """Token-level log-probabilities of responses given prompts."""
    all_logps = []
    for prompt, response in zip(prompts, responses):
        full_text = prompt + response
        # Assumes the prompt tokenization is a prefix of the full-text
        # tokenization (usually true; BPE merges at the boundary can break it).
        prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        full_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(device)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            logits = model(full_ids).logits  # [1, seq, vocab]

        # Shift for next-token prediction
        logits = logits[:, :-1, :]  # predict the next token from the previous ones
        labels = full_ids[:, 1:]

        # Mask: only score response tokens (after the prompt)
        prompt_len = prompt_ids.size(1)
        mask = torch.zeros_like(labels, dtype=torch.bool)
        mask[:, prompt_len - 1:] = True  # after the shift, the response starts here

        # Log-prob of each label token
        log_probs = F.log_softmax(logits, dim=-1)
        token_logps = torch.gather(log_probs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
        token_logps = token_logps * mask  # zero out prompt positions

        # Sum over response tokens
        seq_logp = token_logps.sum(dim=-1)  # [1]
        all_logps.append(seq_logp)
    return torch.cat(all_logps)

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """DPO loss, sigmoid form."""
    # Log ratios
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps

    # DPO logits
    logits = beta * (pi_logratios - ref_logratios)

    # Negative log-sigmoid
    loss = -F.logsigmoid(logits).mean()

    # Reward proxies for logging
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    reward_acc = (chosen_rewards > rejected_rewards).float().mean()

    return loss, {"loss": loss.item(),
                  "reward_acc": reward_acc.item(),
                  "chosen_reward": chosen_rewards.mean().item(),
                  "rejected_reward": rejected_rewards.mean().item()}

# === Training loop ===
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False

for batch in loader:
    prompts, chosens, rejecteds = batch["prompt"], batch["chosen"], batch["rejected"]

    # Policy log-probs
    policy_c_logps = compute_logps(policy_model, tok, prompts, chosens)
    policy_r_logps = compute_logps(policy_model, tok, prompts, rejecteds)

    # Reference log-probs (no grad)
    with torch.no_grad():
        ref_c_logps = compute_logps(ref_model, tok, prompts, chosens)
        ref_r_logps = compute_logps(ref_model, tok, prompts, rejecteds)

    # Loss
    loss, metrics = dpo_loss(policy_c_logps, policy_r_logps,
                             ref_c_logps, ref_r_logps, beta=0.1)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"loss={metrics['loss']:.4f} rwd_acc={metrics['reward_acc']:.3f}")
DPO from scratch in ~80 lines
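The training loop above assumes `policy_model`, `ref_model`, `tok`, `loader`, and `optimizer` already exist. A minimal setup sketch, assuming a small Hugging Face causal LM and an in-memory list `preference_pairs` of prompt/chosen/rejected dicts (the checkpoint name, learning rate, and batch size are illustrative assumptions, not part of the lesson code):

python
# Hypothetical setup for the training loop above; swap in your own checkpoint and data.
import copy
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM2-1.7B"  # assumption: any small causal LM works here
tok = AutoTokenizer.from_pretrained(model_name)
policy_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16).to("cuda")
ref_model = copy.deepcopy(policy_model)  # frozen copy serves as the reference model

optimizer = torch.optim.AdamW(policy_model.parameters(), lr=5e-7)

# preference_pairs: list of {"prompt": ..., "chosen": ..., "rejected": ...} dicts
def collate(batch):
    return {k: [ex[k] for ex in batch] for k in ("prompt", "chosen", "rejected")}

loader = DataLoader(preference_pairs, batch_size=4, shuffle=True, collate_fn=collate)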
✅ Deliverable
  1. Apply the code above to a small model (e.g. SmolLM3 1.7B).
  2. Compare against TRL's DPOTrainer: does it give the same result? A comparison sketch follows below.
  3. Next lesson: 11.4, ORPO Single-Stage.
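For item 2, a minimal sketch of the TRL baseline, assuming a recent TRL release (older versions pass the tokenizer as `tokenizer=` instead of `processing_class=`, and some config fields differ across versions):

python
# Hypothetical TRL baseline run for comparison with the from-scratch loop.
from datasets import Dataset
from trl import DPOConfig, DPOTrainer

train_ds = Dataset.from_list(preference_pairs)  # columns: prompt / chosen / rejected

config = DPOConfig(
    output_dir="dpo-trl-baseline",
    beta=0.1,                          # same beta as the from-scratch loop
    per_device_train_batch_size=4,
    learning_rate=5e-7,
    num_train_epochs=1,
    logging_steps=10,
)

trainer = DPOTrainer(
    model=policy_model,
    ref_model=ref_model,
    args=config,
    train_dataset=train_ds,
    processing_class=tok,
)
trainer.train()
# Compare the trainer's logged loss and reward accuracy with the metrics printed
# by the hand-written loop; they should track each other closely for the same data.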
