Custom autograd.Function ve PyTorch Internals: Kendi Gradient'lerini Yaz
PyTorch autograd'ı extend etmek: torch.autograd.Function subclass'ları, custom forward/backward, ctx ile state saklama, gradcheck doğrulaması, custom CUDA/Triton kernel wrap (preview), FlashAttention block matmul mini-implementasyon, second-order gradients ve gradgradcheck.
Şükrü Yusuf KAYA
55 dakikalık okuma
İleri🔧 PyTorch'u içeriden değiştirmek
Bu dersin sonunda kendi gradient kuralını yazıp PyTorch autograd'a entegre edebileceksin. Bu yetenek FlashAttention, Triton kernel'ları, mixed precision custom ops, quantization-aware training gibi advanced patterns için zorunlu. Modül 37 (CUDA/Triton) için warmup.
Ders Haritası (Detaylı)#
- Niye custom autograd? 6 senaryo
- anatomisi
torch.autograd.Function - İlk örnek: — basit forward/backward
SquareFunction - ile state saklama — save_for_backward
ctx - Multi-input/output Function
- — sayısal doğrulama
torch.autograd.gradcheck - Second-order: ve higher-order
gradgradcheck - Numerik stabil custom op: log-sum-exp örneği
- CUDA kernel wrap (preview Modül 37'den)
- FlashAttention'ın iskeleti — block-wise softmax
- ile op kaydı
torch.library - Production patterns: kernel fusion, dispatcher
1. Niye Custom Autograd?#
Senaryolar:
a) Performance — fused operations#
x.pow(2).sum()b) Numerik stabilite#
log(softmax(x))log_softmaxc) Custom CUDA/Triton kernels#
FlashAttention, RoPE, RMSNorm fused — bunların autograd PyTorch standart op'larıyla efficient yapılamaz. Custom kernel + custom backward.
d) Non-differentiable approximation#
Quantization (INT8): forward'da round(), backward'da identity gradient (straight-through estimator). PyTorch'un round'u zero-gradient verir.
e) Higher-order optimization#
HVP için — bazen manuel implementasyon daha verimli.
grad(grad(f))f) Hardware-specific ops#
TPU XLA call, custom NPU instruction — autograd entegrasyonu özel API gerektirir.
Bu kursta#
LoRA backward (Modül 21), FlashAttention backward (Modül 33), quantization-aware training (Modül 32), custom optimizers (Modül 17) hep custom autograd kullanır.
2. torch.autograd.Function Anatomisi#
torch.autograd.FunctionPyTorch'ta custom backward yazmak için iki yol:
Yol 1: torch.autograd.Function (önerilen)#
torch.autograd.FunctionStatic method'lu class.
Yol 2: Hook-based (legacy)#
Tensor'lara ekle — debug için iyi, production için değil.
register_hookYol 1 iskeleti#
import torch class MyFunction(torch.autograd.Function): @staticmethod def forward(ctx, input, *args, **kwargs): # ctx: context — backward için bilgi sakla ctx.save_for_backward(input) # tensor'ları sakla ctx.some_int_param = args[0] # non-tensor için attr result = ... # forward hesabı return result @staticmethod def backward(ctx, grad_output): # grad_output: ∂L/∂result, shape == result.shape input, = ctx.saved_tensors # ∂L/∂input hesapla grad_input = ... return grad_input # her forward arg'ı için bir gradient (sırayla)
Kullanım#
y = MyFunction.apply(x, 42) # NOT: .apply(), constructor değil!
Önemli kurallar#
- zorunlu
@staticmethod - ilk argüman her ikisi de
ctx - forward'da implicit — gradient'siz çalışıyorsun
@torch.no_grad() - backward'da return sayısı = forward arg sayısı — non-tensor için döndür
None - Kullanım: ,
Cls.apply(...)değilCls(...)
python
import torch class SquareFunction(torch.autograd.Function): """y = x^2 — eğitim örneği.""" @staticmethod def forward(ctx, x): ctx.save_for_backward(x) return x ** 2 @staticmethod def backward(ctx, grad_output): x, = ctx.saved_tensors # dy/dx = 2x; chain rule: dL/dx = dL/dy * 2x grad_input = grad_output * 2 * x return grad_input # Kullanımx = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)y = SquareFunction.apply(x)print("y:", y) # [1., 4., 9.] # Sum + backwardy.sum().backward()print("grad:", x.grad) # [2., 4., 6.] = 2x # Karşılaştırma: built-inx2 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)y2 = x2 ** 2y2.sum().backward()print("builtin grad:", x2.grad) # Aynıİlk custom autograd Function — square.
3. ctx ile State Saklama — save_for_backward#
ctxsave_for_backwardBackward için forward'dan bilgi taşımak gerekir. İki yol:
a) Tensor'lar: ctx.save_for_backward(*tensors)#
ctx.save_for_backward(*tensors)ctx.save_for_backward(input, weight) # ... input, weight = ctx.saved_tensors
PyTorch bunu otomatik olarak versionlar — eğer tensor in-place modify edilirse hata fırlatır.
b) Non-tensor: ctx.attribute#
ctx.attributectx.kernel_size = 3 ctx.bias_used = True # ... ks = ctx.kernel_size
Memory dikkat#
save_for_backwardRecompute trick (activation checkpointing)#
Memory'i azaltmak için: save_for_backward yapmak yerine, backward içinde yeniden hesapla. Trade-off: memory ↓, compute ↑.
@staticmethod def forward(ctx, x, w): ctx.x = x # bunu sakla ctx.w = w # intermediate'i saklama return x @ w + relu(x @ w) @staticmethod def backward(ctx, grad_output): x, w = ctx.x, ctx.w # forward'ı tekrarla intermediate = relu(x @ w) # ... gradient hesabı ...
python
import torch class LinearWithReLUFunction(torch.autograd.Function): """y = ReLU(x @ W + b) — birleşik kernel.""" @staticmethod def forward(ctx, x, W, b): # x: (B, in_features), W: (out_features, in_features), b: (out_features,) z = x @ W.T + b y = torch.clamp(z, min=0) # ReLU # Backward için sakla ctx.save_for_backward(x, W, z) return y @staticmethod def backward(ctx, grad_output): x, W, z = ctx.saved_tensors # ReLU türevi relu_grad = (z > 0).to(grad_output.dtype) grad_z = grad_output * relu_grad # Linear gradient grad_x = grad_z @ W # ∂L/∂x = grad_z @ W grad_W = grad_z.transpose(-2, -1) @ x # ∂L/∂W = grad_z^T @ x grad_b = grad_z.sum(0) # ∂L/∂b = sum over batch return grad_x, grad_W, grad_b # TestB, in_f, out_f = 4, 8, 5x = torch.randn(B, in_f, requires_grad=True)W = torch.randn(out_f, in_f, requires_grad=True)b = torch.randn(out_f, requires_grad=True) y = LinearWithReLUFunction.apply(x, W, b)y.sum().backward() # Karşılaştır: built-inimport torch.nn.functional as Fx2 = x.detach().clone().requires_grad_(True)W2 = W.detach().clone().requires_grad_(True)b2 = b.detach().clone().requires_grad_(True)y2 = F.relu(F.linear(x2, W2, b2))y2.sum().backward() print("x grad diff:", (x.grad - x2.grad).abs().max().item())print("W grad diff:", (W.grad - W2.grad).abs().max().item())print("b grad diff:", (b.grad - b2.grad).abs().max().item())# Hepsi ~1e-7 (FP32 hassasiyeti)Multi-input + multi-output custom Function — bit-exact karşılaştırma.
4. torch.autograd.gradcheck — Doğrulama#
torch.autograd.gradcheckManuel backward yazınca mutlaka gradcheck:
from torch.autograd import gradcheck x = torch.randn(3, requires_grad=True, dtype=torch.float64) input = (x,) # gradcheck — analitik gradient ile sayısal (finite difference) karşılaştır test = gradcheck(SquareFunction.apply, input, eps=1e-6, atol=1e-4) print(test) # True → backward doğru
Önemli detaylar#
- — FP32 hassasiyeti gradcheck için yetmez
dtype=torch.float64 - Küçük tensor'lar — gradcheck O(n²) (her input element için ayrı FD)
- ve
eps:atol,eps=1e-6typicalatol=1e-4
gradgradcheck (second-order)#
Eğer Function higher-order autograd destekliyorsa:
from torch.autograd import gradgradcheck # Backward'dan da gradient alabilir miyiz? test2 = gradgradcheck(SquareFunction.apply, input)
Pratikte higher-order destekli Function nadir. Çoğu zaman ile workaround edilir.
create_graph=True5. Numerik Stabil Custom Op — log_sum_exp Örneği#
log_sum_expSaf overflow:
log(sum(exp(x)))x = torch.tensor([1000.0, 2.0, 3.0]) torch.log(torch.exp(x).sum()) # inf, ardından NaN
Custom Function ile stabil versiyon:
class LogSumExp(torch.autograd.Function): @staticmethod def forward(ctx, x, dim=-1): m = x.max(dim=dim, keepdim=True)[0] shifted = x - m exp_shifted = torch.exp(shifted) sum_exp = exp_shifted.sum(dim=dim, keepdim=True) result = m.squeeze(dim) + torch.log(sum_exp.squeeze(dim)) # softmax probabilities backward için lazım softmax = exp_shifted / sum_exp ctx.save_for_backward(softmax) ctx.dim = dim return result @staticmethod def backward(ctx, grad_output): softmax, = ctx.saved_tensors # d/dx log_sum_exp = softmax(x) grad_input = grad_output.unsqueeze(ctx.dim) * softmax return grad_input, None # dim için None # Test x = torch.tensor([[1000.0, 2.0, 3.0]], requires_grad=True) y = LogSumExp.apply(x, -1) print("y:", y) # 1000.0 (stabil!) y.sum().backward() print("grad:", x.grad) # softmax(x) ≈ [1, 0, 0]
6. CUDA Kernel Wrap (Modül 37 Preview)#
Custom kernel'lerle ciddi performans. Yol:
# my_kernel.cu — CUDA C++ extern "C" __global__ void my_kernel(float* input, float* output, int N) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < N) output[idx] = input[idx] * input[idx]; }
# Python wrap from torch.utils.cpp_extension import load_inline my_kernel = load_inline( name='my_kernel', cpp_sources='', cuda_sources=[open('my_kernel.cu').read()], functions=['my_kernel'], ) class MyCudaSquare(torch.autograd.Function): @staticmethod def forward(ctx, x): out = torch.empty_like(x) my_kernel.my_kernel(x, out, x.numel()) ctx.save_for_backward(x) return out @staticmethod def backward(ctx, grad_output): x, = ctx.saved_tensors return grad_output * 2 * x # Burada da kernel olabilir
Alternatif: Triton (Modül 37)#
import triton import triton.language as tl @triton.jit def square_kernel(x_ptr, y_ptr, N, BLOCK: tl.constexpr): pid = tl.program_id(0) offsets = pid * BLOCK + tl.arange(0, BLOCK) mask = offsets < N x = tl.load(x_ptr + offsets, mask=mask) y = x * x tl.store(y_ptr + offsets, y, mask=mask)
Triton CUDA C++'tan daha temiz, Python-syntax. Modül 37'de detayda.
7. FlashAttention'ın İskeleti — Mini İmplementasyon#
FlashAttention (Dao 2022) attention'ı block-wise hesaplıyor — full softmax matrix'i sakla değil, online log-sum-exp ile devam et.
Kavramsal akış#
N = sequence length, d = head dim Memory: O(N) yerine O(N²)
def flash_attention_naive(Q, K, V, B_r=64, B_c=64): """ Mini FlashAttention — pedagojik. Q, K, V: (B, H, N, d) """ B, H, N, d = Q.shape O = torch.zeros_like(Q) L = torch.zeros(B, H, N, device=Q.device) # log-sum-exp M = torch.full((B, H, N), -float('inf'), device=Q.device) # max n_r = (N + B_r - 1) // B_r n_c = (N + B_c - 1) // B_c for i in range(n_r): # Q bloğunu yükle i_start, i_end = i * B_r, min((i + 1) * B_r, N) Q_i = Q[:, :, i_start:i_end] # (B, H, B_r, d) m_i = M[:, :, i_start:i_end].clone() l_i = L[:, :, i_start:i_end].clone() O_i = O[:, :, i_start:i_end].clone() for j in range(n_c): j_start, j_end = j * B_c, min((j + 1) * B_c, N) K_j = K[:, :, j_start:j_end] V_j = V[:, :, j_start:j_end] # Local scores S_ij = (Q_i @ K_j.transpose(-2, -1)) / (d ** 0.5) m_ij = S_ij.max(dim=-1, keepdim=True)[0] P_ij = torch.exp(S_ij - m_ij) l_ij = P_ij.sum(dim=-1, keepdim=True) # Update m, l with online softmax m_new = torch.max(m_i.unsqueeze(-1), m_ij) l_new = (torch.exp(m_i.unsqueeze(-1) - m_new) * l_i.unsqueeze(-1) + torch.exp(m_ij - m_new) * l_ij) # Update O O_i = (torch.exp(m_i.unsqueeze(-1) - m_new) * l_i.unsqueeze(-1) * O_i + torch.exp(m_ij - m_new) * (P_ij @ V_j)) / l_new m_i = m_new.squeeze(-1) l_i = l_new.squeeze(-1) O[:, :, i_start:i_end] = O_i return O
Bu pedagojik versiyon. Gerçek FlashAttention CUDA kernel'larda — Modül 33'te detayda.
8. torch.library ile Op Kaydı (2.0+)#
torch.libraryPyTorch 2.0+'de API'siyle custom op'lar PyTorch dispatcher'a kaydedilebilir:
torch.libraryimport torch # Op tanımla @torch.library.custom_op("mylib::square", mutates_args=()) def square(x: torch.Tensor) -> torch.Tensor: return x ** 2 # Backward kaydet @square.register_fake def _(x): # Meta function — shape inference için return torch.empty_like(x) @square.register_autograd def _backward(ctx, grad): x, = ctx.saved_tensors return grad * 2 * x # Setup context @square.register_save_inputs_for_backward def _save(x): return (x,) # Kullanım x = torch.randn(3, requires_grad=True) y = torch.ops.mylib.square(x)
Avantajlar#
- ile fully composable
torch.compile - Multi-backend dispatch (CPU/CUDA/MPS)
- ONNX export desteği
- Inspector dostu
9. Production Patterns#
Kernel fusion için custom Function#
Birden çok ufak op → tek custom Function + fused kernel. Memory traffic azaltma + GPU launch overhead düşürme.
Activation checkpointing#
torch.utils.checkpointfrom torch.utils.checkpoint import checkpoint # Eğitimde memory tasarrufu için y = checkpoint(my_heavy_module, x, use_reentrant=False)
Quantization-Aware Training (QAT)#
Forward: quantize → dequantize (fake quantization).
Backward: straight-through estimator — gradient'i quantize gibi davranmadan geçir.
class FakeQuant(torch.autograd.Function): @staticmethod def forward(ctx, x, scale, n_bits=8): q = torch.round(x / scale).clamp(-(2**(n_bits-1)), 2**(n_bits-1)-1) return q * scale @staticmethod def backward(ctx, grad_output): # Straight-through: gradient'i identity ile geç return grad_output, None, None
Modül 32 (Quantization) detayda.
10. Mini Egzersizler#
-
activation:
SwiGLUwhere x split halves. Custom Function ile yaz, gradcheck'le doğrula.f(x) = swish(x_1) * x_2 -
Online softmax: bir vektör için streaming softmax. Memory O(1) (sadece running max + sum) — backward'da nasıl?
-
Gradient clipping custom:ise scale. Forward kolay, backward chain rule. Custom Function gerek mi?
||x|| > τ -
Mixed precision custom: forward FP16, backward FP32. Custom Function bunu nasıl yönetir?
-
gradcheck fail:backward'ında
SquareFunctionyerine2*xyazarsan ne olur? Hata mesajını oku.x
Bu Derste Neler Öğrendik?#
✓ 6 senaryo custom autograd gerektiriyor
✓ anatomisi (static method'lar)
✓ + non-tensor attribute
✓ Multi-input/output Function — gradient sayısı = forward arg sayısı
✓ ile bit-exact doğrulama (FP64 önerilen)
✓ Higher-order:
✓ Numerik stabil custom op (log-sum-exp)
✓ CUDA/Triton kernel wrap (Modül 37 preview)
✓ FlashAttention iskeleti — block-wise online softmax
✓ 2.0+ API — dispatcher entegrasyonu
✓ Production patterns: checkpoint, QAT, mixed precision
torch.autograd.Functionctx.save_for_backwardgradcheckgradgradchecktorch.library🎉 Modül 2 Tamamlandı!#
Toplam 6 ders, ~270 dk içerik. NumPy mühendisliğinden custom autograd'a kadar, PyTorch'un kara kutusu artık açık.
Sıradaki Modül#
Modül 3 — Derin Öğrenmenin Felsefi Tarihi
Perceptron'dan transformer'a 70 yıllık yolculuk. Connectionism vs symbolic, AI Winter'ları, AlexNet patlaması, "Attention Is All You Need". Mevcut hype'ı tarihsel perspektifte konumlandıracağız.
Sık Sorulan Sorular
**`nn.Module`**: parametre içeren, stateful, training/eval modu olan katmanlar (Linear, Conv, BatchNorm). **`torch.autograd.Function`**: stateless gradient hesabını override etmek için tek bir op. Çoğu zaman ikisi birlikte: `nn.Module` parametre + state tutar, içinden `Function.apply()` çağırır. Örnek: FlashAttention `nn.Module` olarak paketlenir ama içinde custom `Function` ile fused kernel'i çağırır.
Yorumlar & Soru-Cevap
(0)Yorum yazmak için giriş yap.
Yorumlar yükleniyor...
İlgili İçerikler
Modül 0: Kurs Çerçevesi ve Atölye Kurulumu
LLM Engineer Kimdir? Junior'dan Staff'a Yapay Zekâ Mühendisliği Kariyer Haritası
Öğrenmeye BaşlaModül 0: Kurs Çerçevesi ve Atölye Kurulumu
Kurs Felsefesi: Neden Bu Yol, Neden Bu Sıra — 8 Aylık Müfredatın İskeleti
Öğrenmeye BaşlaModül 0: Kurs Çerçevesi ve Atölye Kurulumu