
DeepSeek-R1-Distill (Llama-8B / Qwen-7B): Reasoning Trace Distillation — Learning 'Think Tokens'

DeepSeek-R1-Distill — Llama/Qwen bases distilled from R1 (671B) traces: the <think>...</think> format, building a CoT trace dataset, and compressing R1's reasoning into 7-8B models. Your own reasoning FT on an RTX 4090: 1000 R1 traces suffice.

Şükrü Yusuf KAYA
35 min read
Advanced

1. What Is R1, and Why Distill?

DeepSeek-R1 (Jan 2025):
  • 671B MoE (37B active)
  • Pure RL training (no SFT, via GRPO)
  • Self-emergent reasoning behavior (`<think>...</think>`)
  • o1-level on AIME / MATH-500 / GPQA
Distilled versions:
  • DeepSeek-R1-Distill-Llama-8B
    — Llama 3.1 8B base + R1 trace SFT
  • DeepSeek-R1-Distill-Qwen-7B
    — Qwen 2.5 7B base + R1 trace SFT
  • DeepSeek-R1-Distill-Qwen-14B
    / 32B / Llama-70B
Distillation advantage: compress a 671B server-class model into a 7-32B consumer-class one → reasoning capability + edge deployment.
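The compression argument can be made concrete with back-of-the-envelope weight-memory arithmetic (a sketch counting weights only, ignoring KV cache and activations):

```python
def weight_gb(n_params_billion: float, bits_per_param: float) -> float:
    """Approximate weight memory in GB (1 GB = 1e9 bytes, weights only)."""
    return n_params_billion * 1e9 * bits_per_param / 8 / 1e9

print(weight_gb(671, 16))  # full R1 in bf16 → 1342.0 GB (multi-node server)
print(weight_gb(37, 16))   # only the 37B active experts → 74.0 GB, still server-class
print(weight_gb(8, 4))     # distilled 8B in 4-bit → 4.0 GB, fits an RTX 4090
```

This is exactly the gap the distilled checkpoints close: the reasoning behavior of the left column at the memory budget of the last line.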

2. Think Token Format

<|begin_of_text|><|start_header_id|>user<|end_header_id|>
How many equal-area squares does it take to cover a square?
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<think>
Clarify the question: "equal-area square", i.e., small squares of the same size...
Thought 1: 1 (a single square suffices)
Thought 2: but the "how many" question → is it trivial? Maybe the problem statement is incomplete...
</think>
Could you clarify your question?
<|eot_id|>
Anatomy:
  • Content between <think>...</think>
    is the internal reasoning
  • At training time, include this span in the loss (the model learns to think)
  • At inference time, the user usually only wants to see
    what comes after </think>
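The inference-time split described above can be sketched as a small post-processing helper (hypothetical helper name; production serving stacks typically do this for you and expose the reasoning separately):

```python
def split_think(response: str) -> tuple[str, str]:
    """Split a generation into (internal reasoning, visible answer)."""
    if "</think>" in response:
        think, _, answer = response.partition("</think>")
        return think.replace("<think>", "").strip(), answer.strip()
    return "", response.strip()  # model skipped the think block

raw = "<think>1 square of equal area covers the square trivially.</think>Just one."
reasoning, answer = split_think(raw)
print(answer)  # → Just one.
```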
```python
# === Do your own R1-style FT — distill with DeepSeek-R1 ===
# First: run R1 + collect traces (cloud or HF API)
# Then: FT a small base model on these traces

# Step 1 — collect R1 traces for 1000 problems (illustrative; in practice cloud/API)
import openai  # placeholder; in practice the DeepSeek API
from datasets import load_dataset, Dataset

math_problems = load_dataset("openai/gsm8k", "main", split="train[:1000]")
r1_traces = []
for problem in math_problems:
    response = call_r1_api(problem["question"])  # returns "<think>...</think>...answer..."
    r1_traces.append({
        "question": problem["question"],
        "r1_response": response,  # includes <think>
        "gold_answer": problem["answer"],
    })
trace_ds = Dataset.from_list(r1_traces)

# Step 2 — convert the trace dataset to SFT format
def to_sft(ex):
    messages = [
        {"role": "user", "content": ex["question"]},
        {"role": "assistant", "content": ex["r1_response"]},  # think included
    ]
    return {"text": tok.apply_chat_template(messages, tokenize=False)}

# Step 3 — FT Llama 3.1 8B on the traces (it will learn to think)
from unsloth import FastLanguageModel
model, tok = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    max_seq_length=8192,  # think traces are long
    load_in_4bit=True, dtype="bfloat16",
)
model = FastLanguageModel.get_peft_model(
    model, r=64,  # high rank — reasoning capacity
    lora_alpha=128, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
)

# Add <think>, </think> to the tokenizer as special tokens
new_tokens = ["<think>", "</think>"]
n_added = tok.add_tokens(new_tokens, special_tokens=True)
model.resize_token_embeddings(len(tok))
# Byte-decomp init for the new embeddings (Lesson 2.2)

trace_ds = trace_ds.map(to_sft, num_proc=8)  # needs tok, so run after loading it

from trl import SFTTrainer, SFTConfig
cfg = SFTConfig(
    output_dir="llama-8b-reasoning",
    num_train_epochs=3,
    per_device_train_batch_size=1,  # 8K seq → batch=1
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    bf16=True, optim="paged_adamw_8bit",
    max_seq_length=8192, packing=False,  # preserve think-trace integrity
    dataset_text_field="text",
    logging_steps=5, save_steps=100, report_to="wandb",
)
SFTTrainer(model=model, tokenizer=tok, train_dataset=trace_ds, args=cfg).train()
```
Your own R1-style reasoning FT with 1000 traces
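A simple fallback for initializing the two new embedding rows, in place of the byte-decomposition init from Lesson 2.2, is mean initialization: set each new row to the mean of the pre-existing rows so the new tokens start in-distribution. A toy sketch with plain lists standing in for the embedding tensor (with torch you would do `weight[-n_new:] = weight[:-n_new].mean(dim=0)` under `no_grad`):

```python
def mean_init_new_rows(embeddings: list[list[float]], n_new: int) -> list[list[float]]:
    """Overwrite the last n_new rows with the element-wise mean of the older rows."""
    old = embeddings[:-n_new]
    dim = len(old[0])
    mean = [sum(row[j] for row in old) / len(old) for j in range(dim)]
    for i in range(len(embeddings) - n_new, len(embeddings)):
        embeddings[i] = list(mean)
    return embeddings

emb = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # last row = newly added <think> token
print(mean_init_new_rows(emb, 1)[-1])  # → [2.0, 3.0]
```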

3. Reasoning Eval — AIME / GSM8K / MATH-500

```python
import re

def reasoning_eval(model, tok, questions, gold_answers):
    correct = 0
    for q, gold in zip(questions, gold_answers):
        prompt = tok.apply_chat_template(
            [{"role": "user", "content": q}],
            tokenize=False, add_generation_prompt=True,
        )
        out = model.generate(
            tok(prompt, return_tensors="pt").to("cuda").input_ids,
            max_new_tokens=4096, temperature=0.6, do_sample=True,
        )
        response = tok.decode(out[0], skip_special_tokens=True)
        # Extract the final answer (after </think>)
        if "</think>" in response:
            answer_text = response.split("</think>")[-1]
        else:
            answer_text = response
        # Extract the numerical answer
        match = re.search(r"(\d+(?:\.\d+)?)", answer_text)
        if match and abs(float(match.group(1)) - gold) < 1e-6:
            correct += 1
    return correct / len(questions)

acc = reasoning_eval(model, tok, gsm8k_test_q, gsm8k_test_a)
print(f"GSM8K accuracy: {acc:.1%}")
# Cookbook expected (Llama 8B + 1000 R1 traces):
#   Base:          84.5%
#   After distill: 88.4%
```
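Note that GSM8K `answer` fields are strings whose solutions end in a `#### <number>` line, so the `gold` values passed to the eval need to be parsed into floats first; a minimal extractor:

```python
def parse_gsm8k_gold(answer_field: str) -> float:
    """Pull the final numeric answer out of a GSM8K 'answer' string."""
    # GSM8K solutions end with a line like '#### 42'
    tail = answer_field.rsplit("####", 1)[-1]
    return float(tail.strip().replace(",", ""))  # handle '1,000'-style separators

print(parse_gsm8k_gold("She pays 2*3=6 dollars.\n#### 6"))  # → 6.0
```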
✅ Deliverables
  1. Collect traces for 200 GSM8K problems via the DeepSeek-R1 API.
  2. FT Llama 8B on these traces.
  3. Measure GSM8K accuracy.
  4. Next lesson: 3.10 — Yi-1.5 / InternLM2.5 / Aya Expanse.

