Skip to content

Custom autograd.Function and PyTorch Internals: Write Your Own Gradients

Extending PyTorch autograd: torch.autograd.Function subclasses, custom forward/backward, state saving via ctx, gradcheck validation, custom CUDA/Triton kernel wrap (preview), FlashAttention block matmul mini-implementation, second-order gradients and gradgradcheck.

Şükrü Yusuf KAYA
55 min read
Advanced
Custom autograd.Function ve PyTorch Internals: Kendi Gradient'lerini Yaz
🔧 PyTorch'u içeriden değiştirmek
Bu dersin sonunda kendi gradient kuralını yazıp PyTorch autograd'a entegre edebileceksin. Bu yetenek FlashAttention, Triton kernel'ları, mixed precision custom ops, quantization-aware training gibi advanced patterns için zorunlu. Modül 37 (CUDA/Triton) için warmup.

Ders Haritası (Detaylı)#

  1. Niye custom autograd? 6 senaryo
  2. torch.autograd.Function
    anatomisi
  3. İlk örnek:
    SquareFunction
    — basit forward/backward
  4. ctx
    ile state saklama
    — save_for_backward
  5. Multi-input/output Function
  6. torch.autograd.gradcheck
    — sayısal doğrulama
  7. Second-order:
    gradgradcheck
    ve higher-order
  8. Numerik stabil custom op: log-sum-exp örneği
  9. CUDA kernel wrap (preview Modül 37'den)
  10. FlashAttention'ın iskeleti — block-wise softmax
  11. torch.library
    ile op kaydı
  12. Production patterns: kernel fusion, dispatcher

1. Niye Custom Autograd?#

Senaryolar:

a) Performance — fused operations#

x.pow(2).sum()
2 op (pow, sum). Tek bir kernel'le birleştirebilirsin → memory bandwidth yarı. Modern GPU'da bu kritik.

b) Numerik stabilite#

log(softmax(x))
saf hesap underflow yapar. Custom
log_softmax
log-sum-exp trick kullanır. (PyTorch zaten yapıyor ama custom kernel için gerek.)

c) Custom CUDA/Triton kernels#

FlashAttention, RoPE, RMSNorm fused — bunların autograd PyTorch standart op'larıyla efficient yapılamaz. Custom kernel + custom backward.

d) Non-differentiable approximation#

Quantization (INT8): forward'da round(), backward'da identity gradient (straight-through estimator). PyTorch'un round'u zero-gradient verir.

e) Higher-order optimization#

HVP için
grad(grad(f))
— bazen manuel implementasyon daha verimli.

f) Hardware-specific ops#

TPU XLA call, custom NPU instruction — autograd entegrasyonu özel API gerektirir.

Bu kursta#

LoRA backward (Modül 21), FlashAttention backward (Modül 33), quantization-aware training (Modül 32), custom optimizers (Modül 17) hep custom autograd kullanır.

2.
torch.autograd.Function
Anatomisi#

PyTorch'ta custom backward yazmak için iki yol:

Yol 1:
torch.autograd.Function
(önerilen)#

Static method'lu class.

Yol 2: Hook-based (legacy)#

Tensor'lara
register_hook
ekle — debug için iyi, production için değil.

Yol 1 iskeleti#

import torch class MyFunction(torch.autograd.Function): @staticmethod def forward(ctx, input, *args, **kwargs): # ctx: context — backward için bilgi sakla ctx.save_for_backward(input) # tensor'ları sakla ctx.some_int_param = args[0] # non-tensor için attr result = ... # forward hesabı return result @staticmethod def backward(ctx, grad_output): # grad_output: ∂L/∂result, shape == result.shape input, = ctx.saved_tensors # ∂L/∂input hesapla grad_input = ... return grad_input # her forward arg'ı için bir gradient (sırayla)

Kullanım#

y = MyFunction.apply(x, 42) # NOT: .apply(), constructor değil!

Önemli kurallar#

  1. @staticmethod
    zorunlu
  2. ctx
    ilk argüman
    her ikisi de
  3. forward'da
    @torch.no_grad()
    implicit
    — gradient'siz çalışıyorsun
  4. backward'da return sayısı = forward arg sayısı — non-tensor için
    None
    döndür
  5. Kullanım:
    Cls.apply(...)
    ,
    Cls(...)
    değil
python
import torch
 
class SquareFunction(torch.autograd.Function):
"""y = x^2 — eğitim örneği."""
 
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x ** 2
 
@staticmethod
def backward(ctx, grad_output):
x, = ctx.saved_tensors
# dy/dx = 2x; chain rule: dL/dx = dL/dy * 2x
grad_input = grad_output * 2 * x
return grad_input
 
# Kullanım
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = SquareFunction.apply(x)
print("y:", y) # [1., 4., 9.]
 
# Sum + backward
y.sum().backward()
print("grad:", x.grad) # [2., 4., 6.] = 2x
 
# Karşılaştırma: built-in
x2 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y2 = x2 ** 2
y2.sum().backward()
print("builtin grad:", x2.grad) # Aynı
İlk custom autograd Function — square.

3.
ctx
ile State Saklama —
save_for_backward
#

Backward için forward'dan bilgi taşımak gerekir. İki yol:

a) Tensor'lar:
ctx.save_for_backward(*tensors)
#

ctx.save_for_backward(input, weight) # ... input, weight = ctx.saved_tensors
PyTorch bunu otomatik olarak versionlar — eğer tensor in-place modify edilirse hata fırlatır.

b) Non-tensor:
ctx.attribute
#

ctx.kernel_size = 3 ctx.bias_used = True # ... ks = ctx.kernel_size

Memory dikkat#

save_for_backward
tensor'a reference tutar — graph free'lenene kadar memory'de kalır. Büyük intermediate'ler için bilinçli ol.

Recompute trick (activation checkpointing)#

Memory'i azaltmak için: save_for_backward yapmak yerine, backward içinde yeniden hesapla. Trade-off: memory ↓, compute ↑.
@staticmethod def forward(ctx, x, w): ctx.x = x # bunu sakla ctx.w = w # intermediate'i saklama return x @ w + relu(x @ w) @staticmethod def backward(ctx, grad_output): x, w = ctx.x, ctx.w # forward'ı tekrarla intermediate = relu(x @ w) # ... gradient hesabı ...
python
import torch
 
class LinearWithReLUFunction(torch.autograd.Function):
"""y = ReLU(x @ W + b) — birleşik kernel."""
 
@staticmethod
def forward(ctx, x, W, b):
# x: (B, in_features), W: (out_features, in_features), b: (out_features,)
z = x @ W.T + b
y = torch.clamp(z, min=0) # ReLU
# Backward için sakla
ctx.save_for_backward(x, W, z)
return y
 
@staticmethod
def backward(ctx, grad_output):
x, W, z = ctx.saved_tensors
# ReLU türevi
relu_grad = (z > 0).to(grad_output.dtype)
grad_z = grad_output * relu_grad
# Linear gradient
grad_x = grad_z @ W # ∂L/∂x = grad_z @ W
grad_W = grad_z.transpose(-2, -1) @ x # ∂L/∂W = grad_z^T @ x
grad_b = grad_z.sum(0) # ∂L/∂b = sum over batch
return grad_x, grad_W, grad_b
 
# Test
B, in_f, out_f = 4, 8, 5
x = torch.randn(B, in_f, requires_grad=True)
W = torch.randn(out_f, in_f, requires_grad=True)
b = torch.randn(out_f, requires_grad=True)
 
y = LinearWithReLUFunction.apply(x, W, b)
y.sum().backward()
 
# Karşılaştır: built-in
import torch.nn.functional as F
x2 = x.detach().clone().requires_grad_(True)
W2 = W.detach().clone().requires_grad_(True)
b2 = b.detach().clone().requires_grad_(True)
y2 = F.relu(F.linear(x2, W2, b2))
y2.sum().backward()
 
print("x grad diff:", (x.grad - x2.grad).abs().max().item())
print("W grad diff:", (W.grad - W2.grad).abs().max().item())
print("b grad diff:", (b.grad - b2.grad).abs().max().item())
# Hepsi ~1e-7 (FP32 hassasiyeti)
Multi-input + multi-output custom Function — bit-exact karşılaştırma.

4.
torch.autograd.gradcheck
— Doğrulama#

Manuel backward yazınca mutlaka gradcheck:
from torch.autograd import gradcheck x = torch.randn(3, requires_grad=True, dtype=torch.float64) input = (x,) # gradcheck — analitik gradient ile sayısal (finite difference) karşılaştır test = gradcheck(SquareFunction.apply, input, eps=1e-6, atol=1e-4) print(test) # True → backward doğru

Önemli detaylar#

  1. dtype=torch.float64
    — FP32 hassasiyeti gradcheck için yetmez
  2. Küçük tensor'lar — gradcheck O(n²) (her input element için ayrı FD)
  3. eps
    ve
    atol
    :
    eps=1e-6
    ,
    atol=1e-4
    typical

gradgradcheck (second-order)#

Eğer Function higher-order autograd destekliyorsa:
from torch.autograd import gradgradcheck # Backward'dan da gradient alabilir miyiz? test2 = gradgradcheck(SquareFunction.apply, input)
Pratikte higher-order destekli Function nadir. Çoğu zaman
create_graph=True
ile workaround edilir.

5. Numerik Stabil Custom Op —
log_sum_exp
Örneği#

Saf
log(sum(exp(x)))
overflow:
x = torch.tensor([1000.0, 2.0, 3.0]) torch.log(torch.exp(x).sum()) # inf, ardından NaN
Custom Function ile stabil versiyon:
class LogSumExp(torch.autograd.Function): @staticmethod def forward(ctx, x, dim=-1): m = x.max(dim=dim, keepdim=True)[0] shifted = x - m exp_shifted = torch.exp(shifted) sum_exp = exp_shifted.sum(dim=dim, keepdim=True) result = m.squeeze(dim) + torch.log(sum_exp.squeeze(dim)) # softmax probabilities backward için lazım softmax = exp_shifted / sum_exp ctx.save_for_backward(softmax) ctx.dim = dim return result @staticmethod def backward(ctx, grad_output): softmax, = ctx.saved_tensors # d/dx log_sum_exp = softmax(x) grad_input = grad_output.unsqueeze(ctx.dim) * softmax return grad_input, None # dim için None # Test x = torch.tensor([[1000.0, 2.0, 3.0]], requires_grad=True) y = LogSumExp.apply(x, -1) print("y:", y) # 1000.0 (stabil!) y.sum().backward() print("grad:", x.grad) # softmax(x) ≈ [1, 0, 0]

6. CUDA Kernel Wrap (Modül 37 Preview)#

Custom kernel'lerle ciddi performans. Yol:
# my_kernel.cu — CUDA C++ extern "C" __global__ void my_kernel(float* input, float* output, int N) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < N) output[idx] = input[idx] * input[idx]; }
# Python wrap from torch.utils.cpp_extension import load_inline my_kernel = load_inline( name='my_kernel', cpp_sources='', cuda_sources=[open('my_kernel.cu').read()], functions=['my_kernel'], ) class MyCudaSquare(torch.autograd.Function): @staticmethod def forward(ctx, x): out = torch.empty_like(x) my_kernel.my_kernel(x, out, x.numel()) ctx.save_for_backward(x) return out @staticmethod def backward(ctx, grad_output): x, = ctx.saved_tensors return grad_output * 2 * x # Burada da kernel olabilir

Alternatif: Triton (Modül 37)#

import triton import triton.language as tl @triton.jit def square_kernel(x_ptr, y_ptr, N, BLOCK: tl.constexpr): pid = tl.program_id(0) offsets = pid * BLOCK + tl.arange(0, BLOCK) mask = offsets < N x = tl.load(x_ptr + offsets, mask=mask) y = x * x tl.store(y_ptr + offsets, y, mask=mask)
Triton CUDA C++'tan daha temiz, Python-syntax. Modül 37'de detayda.

7. FlashAttention'ın İskeleti — Mini İmplementasyon#

FlashAttention (Dao 2022) attention'ı block-wise hesaplıyor — full softmax matrix'i sakla değil, online log-sum-exp ile devam et.

Kavramsal akış#

N = sequence length, d = head dim Memory: O(N) yerine O(N²)
def flash_attention_naive(Q, K, V, B_r=64, B_c=64): """ Mini FlashAttention — pedagojik. Q, K, V: (B, H, N, d) """ B, H, N, d = Q.shape O = torch.zeros_like(Q) L = torch.zeros(B, H, N, device=Q.device) # log-sum-exp M = torch.full((B, H, N), -float('inf'), device=Q.device) # max n_r = (N + B_r - 1) // B_r n_c = (N + B_c - 1) // B_c for i in range(n_r): # Q bloğunu yükle i_start, i_end = i * B_r, min((i + 1) * B_r, N) Q_i = Q[:, :, i_start:i_end] # (B, H, B_r, d) m_i = M[:, :, i_start:i_end].clone() l_i = L[:, :, i_start:i_end].clone() O_i = O[:, :, i_start:i_end].clone() for j in range(n_c): j_start, j_end = j * B_c, min((j + 1) * B_c, N) K_j = K[:, :, j_start:j_end] V_j = V[:, :, j_start:j_end] # Local scores S_ij = (Q_i @ K_j.transpose(-2, -1)) / (d ** 0.5) m_ij = S_ij.max(dim=-1, keepdim=True)[0] P_ij = torch.exp(S_ij - m_ij) l_ij = P_ij.sum(dim=-1, keepdim=True) # Update m, l with online softmax m_new = torch.max(m_i.unsqueeze(-1), m_ij) l_new = (torch.exp(m_i.unsqueeze(-1) - m_new) * l_i.unsqueeze(-1) + torch.exp(m_ij - m_new) * l_ij) # Update O O_i = (torch.exp(m_i.unsqueeze(-1) - m_new) * l_i.unsqueeze(-1) * O_i + torch.exp(m_ij - m_new) * (P_ij @ V_j)) / l_new m_i = m_new.squeeze(-1) l_i = l_new.squeeze(-1) O[:, :, i_start:i_end] = O_i return O
Bu pedagojik versiyon. Gerçek FlashAttention CUDA kernel'larda — Modül 33'te detayda.

8.
torch.library
ile Op Kaydı (2.0+)#

PyTorch 2.0+'de
torch.library
API'siyle custom op'lar PyTorch dispatcher'a kaydedilebilir:
import torch # Op tanımla @torch.library.custom_op("mylib::square", mutates_args=()) def square(x: torch.Tensor) -> torch.Tensor: return x ** 2 # Backward kaydet @square.register_fake def _(x): # Meta function — shape inference için return torch.empty_like(x) @square.register_autograd def _backward(ctx, grad): x, = ctx.saved_tensors return grad * 2 * x # Setup context @square.register_save_inputs_for_backward def _save(x): return (x,) # Kullanım x = torch.randn(3, requires_grad=True) y = torch.ops.mylib.square(x)

Avantajlar#

  • torch.compile
    ile fully composable
  • Multi-backend dispatch (CPU/CUDA/MPS)
  • ONNX export desteği
  • Inspector dostu

9. Production Patterns#

Kernel fusion için custom Function#

Birden çok ufak op → tek custom Function + fused kernel. Memory traffic azaltma + GPU launch overhead düşürme.

Activation checkpointing#

torch.utils.checkpoint
aslında custom Function — forward'da ara aktivite saklamıyor, backward'da yeniden hesaplıyor. Pattern:
from torch.utils.checkpoint import checkpoint # Eğitimde memory tasarrufu için y = checkpoint(my_heavy_module, x, use_reentrant=False)

Quantization-Aware Training (QAT)#

Forward: quantize → dequantize (fake quantization). Backward: straight-through estimator — gradient'i quantize gibi davranmadan geçir.
class FakeQuant(torch.autograd.Function): @staticmethod def forward(ctx, x, scale, n_bits=8): q = torch.round(x / scale).clamp(-(2**(n_bits-1)), 2**(n_bits-1)-1) return q * scale @staticmethod def backward(ctx, grad_output): # Straight-through: gradient'i identity ile geç return grad_output, None, None
Modül 32 (Quantization) detayda.

10. Mini Egzersizler#

  1. SwiGLU
    activation
    :
    f(x) = swish(x_1) * x_2
    where x split halves. Custom Function ile yaz, gradcheck'le doğrula.
  2. Online softmax: bir vektör için streaming softmax. Memory O(1) (sadece running max + sum) — backward'da nasıl?
  3. Gradient clipping custom:
    ||x|| > τ
    ise scale. Forward kolay, backward chain rule. Custom Function gerek mi?
  4. Mixed precision custom: forward FP16, backward FP32. Custom Function bunu nasıl yönetir?
  5. gradcheck fail:
    SquareFunction
    backward'ında
    2*x
    yerine
    x
    yazarsan ne olur? Hata mesajını oku.

Bu Derste Neler Öğrendik?#

6 senaryo custom autograd gerektiriyor ✓
torch.autograd.Function
anatomisi (static method'lar) ✓
ctx.save_for_backward
+ non-tensor attribute ✓ Multi-input/output Function — gradient sayısı = forward arg sayısı ✓
gradcheck
ile bit-exact doğrulama (FP64 önerilen) ✓ Higher-order:
gradgradcheck
Numerik stabil custom op (log-sum-exp) ✓ CUDA/Triton kernel wrap (Modül 37 preview) ✓ FlashAttention iskeleti — block-wise online softmax ✓
torch.library
2.0+ API
— dispatcher entegrasyonu ✓ Production patterns: checkpoint, QAT, mixed precision

🎉 Modül 2 Tamamlandı!#

Toplam 6 ders, ~270 dk içerik. NumPy mühendisliğinden custom autograd'a kadar, PyTorch'un kara kutusu artık açık.

Sıradaki Modül#

Modül 3 — Derin Öğrenmenin Felsefi Tarihi Perceptron'dan transformer'a 70 yıllık yolculuk. Connectionism vs symbolic, AI Winter'ları, AlexNet patlaması, "Attention Is All You Need". Mevcut hype'ı tarihsel perspektifte konumlandıracağız.

Frequently Asked Questions

**`nn.Module`**: stateful layers with parameters and training/eval modes (Linear, Conv, BatchNorm). **`torch.autograd.Function`**: stateless to override gradient computation for a single op. Often used together: `nn.Module` holds params + state, calls `Function.apply()` inside. Example: FlashAttention is packaged as `nn.Module` but internally calls custom `Function` for fused kernel.

Yorumlar & Soru-Cevap

(0)
Yorum yazmak için giriş yap.
Yorumlar yükleniyor...

Related Content