
Sequence Packing & Variable-Length Attention: The Trick That Boosts Throughput by 40%

Padding tokens are wasted compute. Packing concatenates multiple short examples into one sequence, and variable-length attention (flash_attn_varlen_func) keeps them isolated with a block-diagonal mask. Covers the internals of TRL SFTTrainer packing=True, the anatomy of the cu_seqlens tensor, and a throughput benchmark.

Şükrü Yusuf KAYA
26 min read
Advanced

1. The Cost of Padding

| Mode | Effective tokens | Wasted compute |
|---|---|---|
| Pad-to-max | s_avg × bs | (s_max − s_avg) × bs |
| Bucket sort | ≈ s_avg × bs | 5–10% |
| Packing | s_max × bs | <5% |

On the Turkish dataset the average length is 750 and the max is 4096, so naïve pad-to-max spends roughly 82% of its tokens, and therefore its compute, on padding.
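That 82% figure falls straight out of the length statistics: under pad-to-max the wasted fraction is 1 − s_avg / s_max. A quick sanity check (the length distribution below is a synthetic stand-in, not the actual dataset):

```python
# Back-of-the-envelope padding waste under pad-to-max.
import random

random.seed(0)
max_len = 4096
# Synthetic stand-in for the length distribution (avg ~750, clipped to [1, max_len])
lengths = [min(max(int(random.gauss(750, 350)), 1), max_len) for _ in range(52_000)]

avg_len = sum(lengths) / len(lengths)
wasted = 1 - avg_len / max_len
print(f"avg={avg_len:.0f}  wasted={wasted:.0%}")  # ~82% when avg_len is ~750
```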
```python
# === Variable-length packing — flash_attn_varlen_func ===
import torch
from flash_attn import flash_attn_varlen_func

# Pack 3 examples into one sequence with a 4096-token budget
example_lens = [1500, 1200, 1100]  # 3800 total; the remaining 296 tokens would be padding
cu_seqlens = torch.tensor([0, 1500, 2700, 3800], dtype=torch.int32, device="cuda")
# cu_seqlens[i+1] - cu_seqlens[i] = length of the i-th example
# cu_seqlens shape: [batch_size + 1]

max_seqlen = max(example_lens)
total_tokens = cu_seqlens[-1].item()

# Q, K, V shape: [total_tokens, n_heads, head_dim]
q = torch.randn(total_tokens, 32, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total_tokens, 32, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total_tokens, 32, 128, device="cuda", dtype=torch.bfloat16)

# Block-diagonal attention — example A can't attend to example B
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=True,
)
# out shape: [total_tokens, n_heads, head_dim]
```
flash_attn_varlen_func — variable-length attention
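You almost never write cu_seqlens by hand; it is just an exclusive prefix sum over the example lengths. A minimal helper sketch (build_cu_seqlens is my own illustrative name, not a flash-attn API):

```python
import torch

def build_cu_seqlens(example_lens: list[int], device: str = "cuda") -> torch.Tensor:
    """[0, l0, l0+l1, ...] as int32 — the layout flash_attn_varlen_func expects."""
    cu = torch.zeros(len(example_lens) + 1, dtype=torch.int32, device=device)
    cu[1:] = torch.cumsum(torch.tensor(example_lens, device=device), dim=0)
    return cu

print(build_cu_seqlens([1500, 1200, 1100], device="cpu"))
# tensor([   0, 1500, 2700, 3800], dtype=torch.int32)
```

This also doubles as deliverable 2 below: print it, and read each entry as "where the next example starts in the flat token stream".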

2. TRL SFTTrainer packing=True Internals

SFTConfig(packing=True) does the following under the hood (steps 1–2 are sketched in code below):
  1. Concatenate all dataset examples (with an <eos> separator)
  2. Split the resulting stream into max_seq_length chunks with a sliding window
  3. Make each batch's attention mask block-diagonal (via Flash-Attn varlen)
  4. Loss masking: whatever follows an <eos> token is simply the start of the next example

Advantage: all padding disappears and throughput is maximized. Risk: if the dataset's chat template is broken, "leakage" occurs between packed examples (the FMD from Lesson 2.5).
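The concat-and-chunk idea from steps 1–2 is easy to reproduce outside TRL to build intuition. A simplified sketch (pack_examples is illustrative, not TRL's actual implementation, which also handles streaming buffers and loss masking):

```python
# Simplified concat-and-chunk packing (steps 1-2 above); not TRL's real code.
from typing import Iterable, Iterator

def pack_examples(
    tokenized: Iterable[list[int]],  # token ids, one list per example
    eos_id: int,
    max_seq_length: int,
) -> Iterator[list[int]]:
    buffer: list[int] = []
    for ids in tokenized:
        buffer.extend(ids + [eos_id])      # step 1: concat with <eos> separator
        while len(buffer) >= max_seq_length:
            yield buffer[:max_seq_length]  # step 2: emit fixed-length chunks
            buffer = buffer[max_seq_length:]
    # leftover tail (< max_seq_length) is dropped here; TRL may handle it differently

chunks = list(pack_examples([[5, 6, 7], [8, 9, 10, 11, 12]], eos_id=2, max_seq_length=8))
print(chunks)  # [[5, 6, 7, 2, 8, 9, 10, 11]]
```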

3. Bench (RTX 4090 + Llama 3.1 8B QLoRA)

| Mode | Effective tokens/s | Wall-clock, 1 epoch (52K TR-Alpaca) |
|---|---|---|
| Pad-to-max (s=2048) | 4300 | 4h 12m |
| Bucket-sort (s=1024–2048) | 6100 | 2h 58m |
| Packing | 10200 | 1h 47m |
| Packing + Unsloth | 17500 | 1h 02m |
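If you rerun this benchmark, count effective (non-pad) tokens per second, not raw batch size × sequence length, or pad-to-max will look better than it is. A rough harness sketch (step_fn is a placeholder for your own training step):

```python
# Rough effective-throughput probe: non-pad tokens per wall-clock second.
import time
import torch

def effective_tokens_per_sec(step_fn, batches, pad_token_id: int) -> float:
    """step_fn(batch) runs one optimizer step; batches yield dicts with 'input_ids'."""
    total = 0
    torch.cuda.synchronize()
    start = time.perf_counter()
    for batch in batches:
        step_fn(batch)
        total += (batch["input_ids"] != pad_token_id).sum().item()
    torch.cuda.synchronize()
    return total / (time.perf_counter() - start)
```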
✅ Deliverables
  1. Benchmark the same dataset with packing=False vs packing=True.
  2. Print cu_seqlens and understand its anatomy.
  3. Next lesson: 2.10 — Streaming & Sharded Datasets.

