
Process Reward Models (PRM): Step-Level Supervision — PRM800K Dataset

PRM = a separate reward for each reasoning step. Instead of outcome-only supervision (final answer), it teaches the quality of every intermediate step. Covers OpenAI's PRM800K dataset, automatic PRM data generation with Math-Shepherd, and Step-DPO. The foundation of test-time tree search (Best-of-N, MCTS). Train and use a PRM on an RTX 4090.

Şükrü Yusuf KAYA
30 minute read
Advanced

1. Outcome Reward vs Process Reward#

Outcome Reward Model (ORM):
  reasoning_trace → final_answer → reward  (binary correct/wrong)

Process Reward Model (PRM):
  reasoning_trace, step by step:
  step 1 → reward_1  (is this step OK?)
  step 2 → reward_2
  ...
  step N → reward_N  (is this step OK?)
Advantages of PRM:
  • Catches errors early
  • Step-level gradient → more sample-efficient
  • Test-time: Best-of-N reranking and tree search, including MCTS
Disadvantage: step-level annotation is expensive (800K human-labeled steps in OpenAI's PRM800K).
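A toy example of the two supervision signals side by side (field names and values are illustrative assumptions, not the PRM800K schema):

```python
# ORM: a single binary label for the whole trace
orm_example = {
    "trace": "Step 1: ... Step 2: ... Answer: 42",
    "label": 1,  # 1 = final answer correct, 0 = wrong
}

# PRM: one label per step; an error becomes visible at the step that introduces it
prm_example = {
    "steps": ["Step 1: ...", "Step 2: ...", "Step 3: Answer: 42"],
    "step_labels": [1.0, 0.0, 0.0],  # step 2 is wrong; later steps inherit the error
}
```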

2. PRM Dataset Sources#

| Dataset | Size | Method | License |
| --- | --- | --- | --- |
| PRM800K (OpenAI) | 800K labeled steps | Human annotators | MIT |
| Math-Shepherd | 80K problems × 16 traces | Auto-generated via Monte Carlo | Apache |
| Step-DPO data | varies | Self-generated + filtered | varies |
Math-Shepherd auto-generation:
  • Generate 10-20 continuations for each step
  • Grade each continuation as a whole (correct/wrong)
  • The step's "quality" value = the fraction of correct continuations
This method produces PRM data without any human labels, as sketched below.
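A minimal sketch of the Monte Carlo labeling loop (`generate_continuation` and `is_correct` are hypothetical helpers, not Math-Shepherd's actual code):

```python
def mc_step_quality(model, problem, steps, n_rollouts=16):
    """Math-Shepherd-style soft labels: a step's quality is the fraction
    of sampled continuations from that step that reach a correct answer.
    generate_continuation / is_correct are hypothetical helpers."""
    labels = []
    for i in range(len(steps)):
        prefix = problem + "\n" + "\n".join(steps[: i + 1])
        n_correct = sum(
            is_correct(generate_continuation(model, prefix), problem)
            for _ in range(n_rollouts)
        )
        labels.append(n_correct / n_rollouts)  # soft label in [0, 1]
    return labels
```

These soft labels are exactly the regression targets the MSE loss in the training code below expects.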
```python
# === PRM Training — RTX 4090 ===
# PRM = step-level scorer: a regression head replaces the base model's lm_head
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class PRM(nn.Module):
    def __init__(self, base_model_path):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
        self.scorer = nn.Linear(self.encoder.config.hidden_size, 1)

    def forward(self, input_ids, step_positions):
        """step_positions: index of each step's last token."""
        h = self.encoder(input_ids).last_hidden_state     # [b, seq, hidden]
        step_h = h[:, step_positions]                     # [b, n_steps, hidden]
        scores = self.scorer(step_h.float()).squeeze(-1)  # [b, n_steps]; fp32 for a stable head
        return scores

# Training — regression on step labels
model = PRM("Qwen/Qwen2.5-1.5B")  # a small base model is enough for a PRM
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

# prm_dataset: batches with input_ids, step_positions, step_labels (prepared upstream)
for batch in prm_dataset:
    scores = model(batch["input_ids"], batch["step_positions"])
    # MSE loss: each step's target is its probability of a correct continuation
    loss = F.mse_loss(torch.sigmoid(scores), batch["step_labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
PRM training — step-level scorer
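The loop above assumes `step_positions` come precomputed with each batch. One way to derive them, assuming steps are newline-separated text (the separator convention and the per-step tokenization are simplifying assumptions; PRM800K stores explicit per-step boundaries):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")

def find_step_positions(trace, sep="\n"):
    """Index of each step's last token, assuming newline-separated steps.
    Approximate: tokenizing steps one by one can differ slightly from
    tokenizing the full trace at once (BPE merges across boundaries)."""
    positions, offset = [], 0
    for step in (s for s in trace.split(sep) if s):
        offset += len(tokenizer(step + sep, add_special_tokens=False)["input_ids"])
        positions.append(offset - 1)  # last token of this step
    return positions
```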

3. Test-Time Use — Best-of-N + MCTS#

```python
def best_of_n_with_prm(model, prm, prompt, n=8):
    """Generate N traces, rerank with the PRM."""
    traces = [model.generate(prompt) for _ in range(n)]
    scores = []
    for trace in traces:
        step_scores = prm(trace)  # step scores for the trace (tokenization omitted)
        # Aggregation: min (worst step), mean, or last (final step)
        score = step_scores.min().item()
        scores.append(score)
    best_idx = max(range(n), key=lambda i: scores[i])
    return traces[best_idx]
```
Best-of-8 with a PRM: +3-5% GSM8K accuracy (sampling only, no extra training).
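The heading also promises tree search. A minimal sketch of PRM-guided stepwise beam search, a simpler cousin of full MCTS (`generate_step` is a hypothetical sampler that produces one reasoning step):

```python
def prm_beam_search(model, prm, prompt, beam_width=4, expand=4, max_steps=10):
    """Stepwise beam search guided by the PRM: at each depth, expand every
    partial trace with several candidate steps and keep the beam_width
    candidates whose newest step the PRM scores highest."""
    beams = [prompt]
    for _ in range(max_steps):
        candidates = []
        for partial in beams:
            for _ in range(expand):
                step = generate_step(model, partial)  # hypothetical one-step sampler
                candidate = partial + step
                score = prm(candidate)[-1].item()     # PRM score of the newest step
                candidates.append((score, candidate))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = [c for _, c in candidates[:beam_width]]
    return beams[0]
```

Unlike Best-of-N, this prunes bad branches early instead of paying for N full generations.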
✅ Deliverables
  1. Download 10K samples from PRM800K.
  2. Train Qwen2.5-1.5B as a PRM.
  3. Test Best-of-N.
  4. Next lesson: 11.10 — Constitutional AI + RLAIF.
