İçeriğe geç

GPTQ Algoritması: Optimal Brain Quantization + Hessian Update — RTX 4090'da 12 Dakikada Llama 8B

GPTQ (Frantar et al. 2022) — LLM weight quantization standardı. Optimal Brain Quantization theory (LeCun 1990), Hessian inverse update, error compensation, group quantization. RTX 4090 + auto-gptq ile Llama 3.1 8B'yi 12 dakikada int4'e quantize et. WikiText-2 perplexity delta < %2.

Şükrü Yusuf KAYA
34 dakikalık okuma
İleri
GPTQ Algoritması: Optimal Brain Quantization + Hessian Update — RTX 4090'da 12 Dakikada Llama 8B

1. Optimal Brain Quantization (OBQ) Theory#

GPTQ, Optimal Brain Surgeon (LeCun 1990) ve Optimal Brain Quantizer üzerine kurulu.

Temel fikir:#

Bir layer'ın weights
W ∈ R^{d_out × d_in}
'i quantize ederken, bir column'ı quantize et + kalan column'lar üzerinde kompenze et. "Kompenze" = quantization hatası diğer column'lara yayılmış.

Hessian:#

Quantization hatasını minimize et:
min ||W·X - Ŵ·X||²
  • X
    = layer'a giren activation (calibration set'ten)
  • Ŵ
    = quantized weight
Optimal update için Hessian
H = X X^T
ve onun inverse'i lazım. Hessian
d_in × d_in
(Llama 8B FFN için 14336 × 14336 = 200M params).

2. GPTQ Algoritması — Column-by-Column#

1. Compute Hessian H = X X^T from calibration data 2. Cholesky decomposition: H^(-1) = L L^T 3. For each column i (left to right): a. Quantize column i: w_i' = quantize(w_i) b. Compute error: e = (w_i - w_i') / L[i,i] c. Update remaining columns: W[:, i+1:] -= e × L[i, i+1:] 4. Return quantized W'
Cookbook açıklama: Her column quantize edildiğinde, oluşan hata sonraki column'larda proaktif olarak kompenze edilir. Bu naïve "quantize each weight independently"ye göre %15-30 perplexity iyileştirir.

Group-wise GPTQ:#

group_size = 128
ile her 128 column birlikte quantize edilir. Group içinde paylaşılan scale → memory tasarrufu + kalite.
python
# === GPTQ ile Llama 3.1 8B int4 quantization — RTX 4090, 12 dakika ===
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from datasets import load_dataset
 
# 1. Calibration dataset (önemli! dağılım modelin gerçek inference'ıyla benzer olmalı)
calibration_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:512]")
 
tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
 
def to_examples(data):
examples = []
for row in data:
text = row["text"]
if len(text.strip()) > 10:
inputs = tok(text, return_tensors="pt", max_length=2048, truncation=True)
examples.append({"input_ids": inputs["input_ids"][0],
"attention_mask": inputs["attention_mask"][0]})
return examples[:128]
 
calibration = to_examples(calibration_data)
 
# 2. Quantization config
quant_cfg = BaseQuantizeConfig(
bits=4,
group_size=128,
desc_act=False, # activation reorder (true daha iyi ama yavaş)
sym=True,
true_sequential=True, # sequential layer-by-layer
)
 
# 3. Load + quantize
model = AutoGPTQForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
quantize_config=quant_cfg,
torch_dtype="bfloat16",
device_map="cuda",
)
 
# 4. Quantize — RTX 4090, ~12 dakika
model.quantize(calibration)
 
# 5. Save (~4.5 GB instead of 16 GB)
model.save_quantized("llama-3.1-8b-int4-gptq")
tok.save_pretrained("llama-3.1-8b-int4-gptq")
 
# Bench:
# - bf16 size: 16 GB
# - GPTQ int4 size: 4.5 GB (-72%)
# - WikiText-2 perplexity: bf16 5.93 → int4 6.04 (+1.9%)
# - Inference speed (vLLM): bf16 95 tok/s → int4 165 tok/s (+74%)
auto-gptq ile Llama 3.1 8B int4 quantization

3. Calibration Dataset Mühendisliği#

GPTQ'nun kalitesi calibration data'ya çok bağlı:
CalibrationWikiText-2 PPL (post-quant)TR-MMLU (post-quant)
WikiText-2 (default)6.04 (+1.9%)32.2 (-0.2)
Random tokens (kötü)7.85 (+32%)29.8 (-2.6)
C4 multilingual6.10 (+2.8%)32.3 (-0.1)
TR text mix (TR-specific)6.18 (+4.2%) en model33.0 (+0.6) TR best
In-domain (model'in real-life prompts)5.97 (+0.7%)best for in-domain
Cookbook'un kuralı: Production deploy edeceğin model için in-domain calibration set kullan. Açık benchmark için WikiText-2 fine.
🐛 FMD — 'GPTQ sonrası model output garbage (sadece tek karakter tekrarı)'
Hipotezler: (a) calibration data quality kötü (random token, çok kısa) → Hessian incorrectly estimated. Çözüm: WikiText-2 veya C4 ile 128-512 sample. (b) `desc_act=False` ama model bunu gerektiriyor (büyük modeller için True önerilir). Çözüm: `desc_act=True` retry (yavaş ama daha doğru). (c) `true_sequential=False` kaldı → sayısal hata birikti. (d) bf16 → int4 dönüşümde NaN slip. Drill: WikiText-2 perplexity'yi her layer'dan sonra ölç, hangi layer'da divergence başlıyor bul.
✅ Teslim
  1. Llama 3.1 8B'yi GPTQ ile quantize et (12 dakika). 2) WikiText-2 perplexity delta'sını ölç. 3) vLLM ile inference throughput'u karşılaştır. 4) Sonraki ders: 10.3 — AWQ Algoritması.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

İlgili İçerikler