SFT on Reasoning Traces: Llama-8B + R1-Distilled Traces (8K → 32K Context)
Once the reasoning trace dataset is ready, SFT is technically simple, but the details matter: adding the `<think>` token to the vocab, embedding init, 32K context length (R1 traces run 5-15K tokens), loss masking (do think tokens enter the loss or not?), and epoch count. On an RTX 4090, Llama 3.1 8B + 1000 R1 traces takes ~50 minutes per epoch.
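On the loss-masking question, one common choice for R1-style distillation is to mask only the prompt and keep the think trace plus the final answer in the loss. The helper below is a minimal sketch of that idea (the function name and the Llama-3 header assumption are illustrative, not from this lesson's training script, which trains on the full formatted text):

```python
# Sketch: completion-only loss masking. Prompt positions get label -100,
# which PyTorch's CrossEntropyLoss ignores by default; the <think> trace
# and the visible answer both remain in the loss.
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_prompt_labels(input_ids: list[int], prompt_len: int) -> list[int]:
    """Labels = input_ids, with the first prompt_len positions masked out."""
    labels = list(input_ids)
    labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    return labels

# Toy demo with fake token ids: a 4-token prompt + 3-token response
ids = [11, 12, 13, 14, 101, 102, 103]
labels = mask_prompt_labels(ids, prompt_len=4)
print(labels)  # [-100, -100, -100, -100, 101, 102, 103]
```

In practice `prompt_len` comes from tokenizing the chat template up to the assistant header; masking the think tokens as well would instead teach the model to skip reasoning entirely.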
Şükrü Yusuf KAYA
28 minute read
Advanced · Python
```python
# === Reasoning SFT — Llama 3.1 8B + R1 traces ===
import torch
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

# 1. Model + tokenizer
model, tok = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    max_seq_length=32768,   # long think traces
    dtype=torch.bfloat16,
    load_in_4bit=True,
)

# 2. Add <think> tokens.
# Compute the subword decomposition *before* adding the new tokens;
# afterwards the tokenizer returns the new token's own id and the init is a no-op.
new_tokens = ["<think>", "</think>"]
decomp_ids = [tok.encode(t, add_special_tokens=False) for t in new_tokens]
n_added = tok.add_tokens(new_tokens, special_tokens=True)
model.resize_token_embeddings(len(tok))

# Byte-decomp init (Part II 2.2):
# e.g. </think> = mean of the "</", "think", ">" piece embeddings
emb = model.get_input_embeddings().weight.data
old_vocab = len(tok) - n_added
for i, ids in enumerate(decomp_ids):
    if ids:
        emb[old_vocab + i] = emb[ids].mean(dim=0).to(emb.dtype)

# 3. LoRA — high rank for reasoning capacity
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing="unsloth",
)

# 4. R1 traces dataset
traces = load_dataset("json", data_files="r1_traces_gsm8k.jsonl", split="train")

def format_example(ex):
    messages = [
        {"role": "user", "content": ex["question"]},
        {"role": "assistant", "content": ex["response"]},  # includes <think>
    ]
    return {"text": tok.apply_chat_template(messages, tokenize=False)}

traces = traces.map(format_example)

# 5. SFT — batch=1 for long context
cfg = SFTConfig(
    output_dir="llama-3.1-8b-reasoning",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    bf16=True,
    optim="paged_adamw_8bit",
    max_seq_length=32768,
    packing=False,          # keep each think trace intact
    dataset_text_field="text",
    logging_steps=5,
    report_to="wandb",
)
SFTTrainer(model=model, tokenizer=tok, train_dataset=traces, args=cfg).train()

# Bench:
# - Wall-clock: 50-60 min (1000 traces)
# - GSM8K accuracy: Llama 8B base 84.5 → SFT 88.4
# - GPQA: 22.0 → 24.5
```
✅ Deliverables
1) SFT Llama 8B on 1000 R1 traces. 2) Benchmark GSM8K + GPQA. 3) Next lesson: 12.4 — GRPO RL Stage.