
GRPO (Group Relative Policy Optimization): DeepSeek-R1's Verifiable Reward Recipe

GRPO (DeepSeek, 2024) is a simplified variant of PPO: no critic/value head. Sample G different responses per prompt in a batch and normalize the **relative rewards** within the group. Verifiable rewards (math correctness, code execution) make reasoning RL practical. On an RTX 4090, Qwen-7B + GRPO lifts GSM8K accuracy by 5-8%.

Şükrü Yusuf KAYA
38 min read
Advanced

1. GRPO Anatomy — How It Differs from PPO

| Component | PPO | GRPO |
| --- | --- | --- |
| Actor (policy) | ✅ | ✅ |
| Critic (value head) | ✅ (extra ~1B params) | ❌ |
| Reward model | ✅ (extra ~8B params) | ❌ (verifiable rewards) |
| Group sampling | ❌ | ✅ (G = 8-16 responses/prompt) |
| Advantage estimation | GAE (V from critic) | Group mean baseline |
| Memory (fine-tuning) | 4 models at once | 2 models at once |
GRPO's simplification: drop the value head. For each prompt, sample 8-16 responses and use the group's mean reward as the baseline. Because the reward is verifiable, no reward model is needed. The short sketch below shows the group-relative advantage computation.
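A minimal sketch, assuming nothing beyond PyTorch (grpo_advantages is a hypothetical helper name, not a TRL function; the 1e-8 epsilon guards against zero-variance groups):

```python
import torch

def grpo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """Group-relative advantage: z-score the rewards within one prompt's group."""
    return (rewards - rewards.mean()) / (rewards.std() + 1e-8)

# One prompt, G = 8 sampled responses, scored by a verifiable reward
rewards = torch.tensor([1.0, -0.5, 1.0, -0.5, -0.5, 1.0, -0.5, -0.5])
print(grpo_advantages(rewards))
# correct answers get a positive advantage, wrong ones negative; no critic needed
```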

2. The GRPO Objective

J(θ) = E[ 1/G · Σ_i min( π_θ(y_i)/π_old(y_i) · A_i, clip(π_θ(y_i)/π_old(y_i), 1−ε, 1+ε) · A_i ) ] − β · KL(π_θ || π_ref)

A_i = (r_i − mean(r_1, …, r_G)) / std(r_1, …, r_G)
  • G = group size (DeepSeek used 64; the cookbook uses 4-8 on an RTX 4090)
  • r_i = the i-th response's verifiable reward
  • A_i = group-normalized advantage
  • ε = PPO-style clip range (TRL defaults to 0.2)
  • β = KL coefficient; keep the KL against the SFT/reference policy
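As a toy illustration, here is the objective as a PyTorch function over summed per-response log-probs (the function name, shapes, and the naive KL estimate are my simplifications, not the TRL implementation):

```python
import torch

def grpo_loss(logp_new, logp_old, logp_ref, adv, eps=0.2, beta=0.04):
    """Clipped surrogate with group-normalized advantages plus a KL-to-reference penalty.

    logp_*: log-prob of each of the G sampled responses under the new/old/ref policy, shape (G,)
    adv:    A_i from the group z-score, shape (G,)
    """
    ratio = torch.exp(logp_new - logp_old)               # π_θ(y_i) / π_old(y_i)
    surrogate = torch.min(ratio * adv,
                          torch.clamp(ratio, 1 - eps, 1 + eps) * adv)
    kl = (logp_new - logp_ref).mean()                    # naive estimate; DeepSeek uses the k3 estimator
    return -(surrogate.mean() - beta * kl)               # negate: optimizers minimize
```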

3. Verifiable Rewards — The Heart of Reasoning RL

```python
def math_reward(prompt, response, gold_answer):
    """GSM8K-style math reward — regex extract + compare."""
    import re
    # Extract final numerical answer
    match = re.search(r"####\s*(-?\d+(?:\.\d+)?)", response)
    if not match:
        return -1.0  # no answer format
    pred = float(match.group(1))
    if abs(pred - float(gold_answer)) < 1e-6:
        return 1.0  # correct
    return -0.5  # wrong

def code_reward(prompt, response, test_cases):
    """Code reward — execute response, check test cases pass."""
    # extract_code_block / run_test_cases are helper stubs assumed by the cookbook
    code = extract_code_block(response)
    if not code:
        return -1.0
    try:
        passed = run_test_cases(code, test_cases, timeout=5)
        return passed / len(test_cases)  # fraction passed
    except Exception:
        return -1.0

def format_reward(response):
    """Format adherence — does it have <think>...</think>?"""
    if "<think>" in response and "</think>" in response:
        return 0.2
    return 0.0

# Combined (tests unused for math prompts; code tasks would add code_reward)
def combined_reward(prompt, response, gold, tests):
    return math_reward(prompt, response, gold) + format_reward(response)
```
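A quick sanity check of math_reward (the example strings are made up):

```python
resp_ok  = "48 / 2 = 24, then 24 + 6 = 30.\n#### 30"
resp_bad = "The answer is thirty."
print(math_reward("q", resp_ok, "30"))   # 1.0: extracted 30, matches gold
print(math_reward("q", resp_bad, "30"))  # -1.0: no '#### <number>' pattern
```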
```python
# === GRPO Lab — Qwen 2.5 7B + GSM8K + RTX 4090 ===
# One of the most advanced Labs in the cookbook
import re

import torch
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset

model, tok = FastLanguageModel.from_pretrained(
    "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
    max_seq_length=2048,  # kept short for GRPO
    dtype=torch.bfloat16, load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model, r=64, lora_alpha=128, lora_dropout=0.0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    use_gradient_checkpointing="unsloth",
)

# Dataset — GSM8K. GRPOTrainer expects a "prompt" column; extra columns
# (here "gold_answer") are forwarded to the reward function as kwargs.
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(lambda ex: {
    "prompt": ex["question"],
    "gold_answer": ex["answer"].split("####")[-1].strip().replace(",", ""),
})

# Reward function
def reward_func(prompts, completions, **kwargs):
    """Return one reward per completion."""
    rewards = []
    for i, completion in enumerate(completions):
        gold = kwargs["gold_answer"][i]
        match = re.search(r"####\s*(-?\d+)", completion)
        if match and match.group(1) == gold:
            rewards.append(1.0)
        else:
            rewards.append(-0.5)
    return rewards

cfg = GRPOConfig(
    output_dir="qwen-7b-grpo-math",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # batch × accum must be divisible by num_generations
    num_generations=4,  # G = 4 (for the 4090)
    learning_rate=5e-6,
    bf16=True, optim="paged_adamw_8bit",
    max_prompt_length=512,
    max_completion_length=1024,
    beta=0.04,  # KL coef
    logging_steps=5, save_steps=100, report_to="wandb",
)

trainer = GRPOTrainer(
    model=model, args=cfg,
    reward_funcs=[reward_func],
    train_dataset=dataset,
    processing_class=tok,  # recent TRL takes processing_class rather than tokenizer
)
trainer.train()

# Bench:
# - GSM8K accuracy: Qwen 7B base 85.4 → GRPO 91.2 (+5.8)
# - Wall-clock: 6-8 hours, RTX 4090 + 4 generations/prompt
# - Peak GB: 13.5 (multi-sample memory)
```
GRPO Lab — Qwen 7B + GSM8K + RTX 4090
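For deliverable step 2 below, a minimal before/after measurement sketch, scoring with the same #### regex as the reward (greedy decoding; the sample size and bare-question prompting are my assumptions, not the cookbook's eval harness):

```python
import re
from datasets import load_dataset

def gsm8k_accuracy(model, tok, test_ds, n=200, max_new_tokens=512):
    """Greedy-decode n test questions and score exact match on the #### answer."""
    correct = 0
    for ex in test_ds.select(range(n)):
        gold = ex["answer"].split("####")[-1].strip().replace(",", "")
        inputs = tok(ex["question"], return_tensors="pt").to(model.device)
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        text = tok.decode(out[0], skip_special_tokens=True)
        m = re.search(r"####\s*(-?\d+)", text)
        correct += int(bool(m) and m.group(1) == gold)
    return correct / n

# test = load_dataset("openai/gsm8k", "main", split="test")
# FastLanguageModel.for_inference(model)  # Unsloth's fast-generation mode
# print(gsm8k_accuracy(model, tok, test)) # run once before and once after trainer.train()
```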
✅ Deliverable
  1. Run the GRPO Lab.
  2. Measure GSM8K accuracy, base vs. post-GRPO (the eval sketch above shows one way).
  3. Next lesson: 11.8 — Reward Function Engineering.
