SFT on Reasoning Traces: Llama-8B + R1-Distilled Traces (8K → 32K Context)

Once a reasoning-trace dataset is ready, SFT is technically simple, but the details matter: adding <think> tokens to the vocabulary, initializing their embeddings, extending the context length to 32K (R1 traces run 5-15K tokens), loss masking (should think tokens contribute to the loss?), and choosing the epoch count. Fine-tuning Llama 3.1 8B on 1000 R1 traces for one epoch takes roughly 50 minutes on an RTX 4090.
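One of those details, loss masking, is not shown in the main script below: prompt tokens are usually excluded from the loss so the model only learns to produce the response. A minimal sketch of completion-only masking over plain Python token-id lists (the ids and the response-marker id are invented for illustration):

```python
IGNORE_INDEX = -100  # HF cross-entropy skips labels equal to -100

def mask_prompt_labels(input_ids, response_start_id):
    """Copy input_ids to labels, masking everything up to and
    including the first occurrence of response_start_id."""
    labels = list(input_ids)
    if response_start_id in input_ids:
        cut = input_ids.index(response_start_id) + 1
        labels[:cut] = [IGNORE_INDEX] * cut
    return labels

# Toy ids: [user tokens..., 99 = assistant-header token, think/answer tokens...]
ids = [5, 6, 7, 99, 11, 12, 13]
print(mask_prompt_labels(ids, response_start_id=99))
# [-100, -100, -100, -100, 11, 12, 13]
```

In practice trl ships a collator (`DataCollatorForCompletionOnlyLM`) that achieves the same effect by matching a response-template string; whether the <think> span itself stays in the loss is a separate choice — here it does, since reasoning is exactly what we want to learn.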

Şükrü Yusuf KAYA
28 min read
Advanced
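The script below loads r1_traces_gsm8k.jsonl with question/response fields, where the response embeds the full think trace; a minimal sketch of one such record (the content is invented for illustration):

```python
import json

record = {
    "question": "Natalia sold 48 clips in April and half as many in May. How many in total?",
    "response": "<think>April: 48. May: 48 / 2 = 24. Total: 48 + 24 = 72.</think>\nThe answer is 72.",
}
line = json.dumps(record)    # one JSON object per line in the .jsonl file
parsed = json.loads(line)
print(parsed["response"].startswith("<think>"))  # True
```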
```python
# === Reasoning SFT — Llama 3.1 8B + R1 traces ===
import torch
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

# 1. Model + tokenizer
model, tok = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    max_seq_length=32768,  # long think traces
    dtype=torch.bfloat16, load_in_4bit=True,
)

# 2. Add <think> tokens
new_tokens = ["<think>", "</think>"]
# Capture byte decompositions BEFORE adding the tokens — afterwards the
# tokenizer encodes "<think>" as the new single token itself
decomp_ids = {t: tok.encode(t, add_special_tokens=False) for t in new_tokens}
n_added = tok.add_tokens(new_tokens, special_tokens=True)
model.resize_token_embeddings(len(tok))

# Byte-decomp init (Part II 2.2)
emb = model.get_input_embeddings().weight.data
old_vocab = len(tok) - n_added
for i, tok_str in enumerate(new_tokens):
    # e.g. for "</think>": mean of its "</", ">" piece embeddings
    decomp = decomp_ids[tok_str]
    if decomp:
        emb[old_vocab + i] = emb[decomp].mean(dim=0).to(emb.dtype)

# 3. LoRA — high rank for reasoning capacity
model = FastLanguageModel.get_peft_model(
    model, r=64, lora_alpha=128, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing="unsloth",
)

# 4. R1 traces dataset
traces = load_dataset("json", data_files="r1_traces_gsm8k.jsonl", split="train")
def format_example(ex):
    messages = [
        {"role": "user", "content": ex["question"]},
        {"role": "assistant", "content": ex["response"]},  # includes <think>
    ]
    return {"text": tok.apply_chat_template(messages, tokenize=False)}
traces = traces.map(format_example)

# 5. SFT — batch=1 for long context
cfg = SFTConfig(
    output_dir="llama-3.1-8b-reasoning",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    bf16=True, optim="paged_adamw_8bit",
    max_seq_length=32768, packing=False,  # keep each think trace intact
    dataset_text_field="text",
    logging_steps=5, report_to="wandb",
)
SFTTrainer(model=model, tokenizer=tok, train_dataset=traces, args=cfg).train()

# Bench:
# - Wall-clock: 50-60 min (1000 traces)
# - GSM8K accuracy: Llama 8B base 84.5 → SFT 88.4
# - GPQA: 22.0 → 24.5
```
Reasoning SFT — Llama 3.1 8B + R1 traces
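Scoring GSM8K after SFT requires stripping the think span from each generation before matching the final answer. A small parsing helper, assuming the model reliably closes its </think> tag (the example generation is invented):

```python
import re

def split_think(text):
    """Return (think, answer); think is "" when no <think> block is found."""
    m = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL)
    if not m:
        return "", text.strip()
    return m.group(1).strip(), text[m.end():].strip()

out = "<think>7 * 8 = 56, minus 6 is 50.</think>\nThe answer is 50."
think, answer = split_think(out)
print(answer)  # The answer is 50.
```

A robust evaluation loop should also handle truncated generations where </think> never appears — here the whole text falls through as the answer.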
✅ Deliverables
  1) Fine-tune Llama 8B with 1000 R1 traces. 2) Benchmark on GSM8K + GPQA. 3) Next lesson: 12.4 — GRPO RL Stage.
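With the config above, the optimizer-step count is easy to sanity-check: 1000 traces at per_device_train_batch_size=1 with gradient_accumulation_steps=8 give an effective batch of 8, i.e. 125 optimizer steps per epoch and 375 over 3 epochs:

```python
import math

n_traces, batch, grad_accum, epochs = 1000, 1, 8, 3
effective_batch = batch * grad_accum              # sequences per optimizer step
steps_per_epoch = math.ceil(n_traces / effective_batch)
print(effective_batch, steps_per_epoch, steps_per_epoch * epochs)  # 8 125 375
```

At ~50-60 minutes wall-clock, that is roughly 8-10 seconds per optimizer step — plausible for 5-15K-token sequences on a single RTX 4090 with gradient checkpointing.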
