Autograd and Backpropagation in PyTorch
- Autograd automates the Chain Rule by building a dynamic computational graph during the forward pass and traversing it in reverse during backward(). Understanding the graph lifecycle — construction, retention, destruction — is the prerequisite for debugging memory leaks and gradient errors.
- Always call optimizer.zero_grad() before the forward pass. PyTorch accumulates gradients by default. Forgetting this lets gradients stack across batches, silently corrupting the model over dozens of epochs before the divergence becomes visible in loss metrics.
- Use .item() for scalar logging and .detach().cpu() for tensor logging. Every tensor stored without detaching holds a reference to its entire computational graph, causing linear GPU memory growth that terminates in OOM.
- Autograd is PyTorch's automatic differentiation engine that records operations on tensors and computes gradients via backpropagation
- Set requires_grad=True on a tensor to track all operations performed on it in a dynamic computational graph
- Call .backward() to compute gradients — PyTorch applies the Chain Rule automatically through the recorded graph
- Always call optimizer.zero_grad() before each backward pass — PyTorch accumulates gradients by default, which is useful for gradient accumulation but catastrophic if forgotten in standard training loops
- Wrap inference code in torch.inference_mode() for maximum performance — it is 10-20% faster than torch.no_grad() because it skips version counter tracking
- The dynamic graph (Define-by-Run) means Python if/else and loops inside your model just work — unlike static graph frameworks where conditional logic requires special ops
- Gradient checkpointing trades ~30% more compute for ~60% less memory — the standard tool when a model does not fit on one GPU
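The inference-context point above can be sketched in a few lines. This is a minimal example of the serving-time pattern, not a benchmark of the 10-20% speedup claim:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
model.eval()  # disable dropout / batch-norm training behavior
x = torch.randn(8, 16)

# inference_mode() disables graph construction AND version-counter tracking,
# which is why it is faster than no_grad() for pure serving paths
with torch.inference_mode():
    logits = model(x)

print(logits.requires_grad)  # False, no graph was built
```

One caveat worth knowing: tensors created inside inference_mode() cannot later participate in autograd at all, which is exactly the right property for serving but makes no_grad() the better choice when you need to mix tracked and untracked computation.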
Loss becomes NaN during training

```python
torch.autograd.set_detect_anomaly(True)
print([(name, p.grad.norm().item()) for name, p in model.named_parameters() if p.grad is not None])
```

GPU memory keeps growing each epoch

```python
torch.cuda.memory_summary(device=None, abbreviated=False)
print(loss.item())  # NOT print(loss) — .item() returns a Python float with no graph reference
```

Gradients are all zeros — model not learning

```python
for n, p in model.named_parameters():
    print(n, p.requires_grad, p.grad is not None,
          p.grad.norm().item() if p.grad is not None else 'NO GRAD')
torch.autograd.gradcheck(model, inputs, eps=1e-6, atol=1e-4)  # numerically verify gradient computation
```

Production Incident
A training loop was missing optimizer.zero_grad(). PyTorch accumulates gradients by default — every backward() call adds to existing gradient values in the .grad attribute rather than replacing them. Over 50 epochs of accumulation, gradient magnitudes grew without bound. The optimizer was applying increasingly large weight updates in directions determined by the sum of 50 epochs of gradients rather than the current batch. Weights grew until intermediate activations overflowed to floating-point infinity, which propagated to NaN in subsequent operations. Training loss appeared healthy because the optimizer's massive corrections happened to reduce the accumulated loss signal, masking the divergence until inference ran on clean production data without the distortion.

The fix: call optimizer.zero_grad() at the start of every training step — before the forward pass, not after. Added gradient clipping using torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) as a permanent safety net. Added NaN detection in the validation loop using torch.isnan(val_loss).any() to catch divergence within the same epoch it occurs. Added gradient norm logging to the monitoring dashboard — a spike in gradient norms is the earliest signal of accumulation bugs or learning rate issues, appearing several epochs before loss divergence becomes visible.

Lessons learned:
- Call optimizer.zero_grad() before the forward pass, not after the backward pass, because the position relative to the forward matters for gradient accumulation workflows.
- Monitor gradient norms during training — the norm should be stable within an order of magnitude. Sudden spikes indicate accumulation bugs, learning rate issues, or outlier batches before the loss shows any sign of divergence.
- Add NaN detection in validation loops — training loss can appear healthy while the model silently diverges because accumulated gradients produce weight corrections that happen to reduce the loss signal even as weights grow unbounded.
- Gradient clipping is not optional for production training — treat max_norm=1.0 as a default starting point and tune from there. It prevents a single outlier batch from causing catastrophic weight updates that take multiple epochs to detect.

Production Debug Guide

Common symptoms when gradient computation goes wrong — and where to look first.
Model collapses or stops improving. Inspect per-parameter gradient norms (print([(name, p.grad.norm().item()) for name, p in model.named_parameters() if p.grad is not None])). If gradients are zero or near-zero, check for dead ReLU neurons caused by large negative biases, weights initialized to zero, or in-place operations that broke the graph. If gradients are non-zero but the model still collapses, suspect a loss function that does not create sufficient gradient signal to discriminate between inputs.

GPU memory grows each epoch. Make sure logged losses are stored as Python floats (loss.item()). Run torch.cuda.memory_summary() to see which allocation is growing. Also check for retain_graph=True in a loop — it keeps every graph alive.

Gradients are None or not flowing. Check for (1) torch.no_grad() or torch.inference_mode() wrapping the training loop accidentally, (2) in-place operations on leaf tensors that break the graph silently, (3) calling .numpy() on a tracked tensor, which detaches it from the graph, (4) a disconnected computation path where the loss does not actually depend on the parameters. Use model.named_parameters() to confirm requires_grad is True for all trainable parameters.

Backward pass is slow or memory-heavy. Use torch.profiler.profile() to identify which backward operation is the bottleneck. Common causes: naive attention implementations that materialize the full N×N attention matrix in memory, operations that create large intermediate tensors in backward but not forward, or missing gradient checkpointing on deep models where backward must traverse a very long graph. For transformers, verify you are using memory-efficient attention kernels (Flash Attention or scaled_dot_product_attention in PyTorch 2.0+).

Autograd and Backpropagation in PyTorch is the engine that makes neural network training work. Understanding it at a mechanical level — not just as a black box — is what separates engineers who can debug training failures from engineers who restart the training run and hope for different results.
PyTorch uses a Define-by-Run approach where the computational graph is built dynamically during the forward pass. This means the graph structure changes with every iteration if your model contains conditional logic or variable-length sequences — something static graph frameworks handle awkwardly through specialized conditional operations that bear little resemblance to ordinary Python code.
The practical consequence: gradients are always correct for the exact computation that just ran, not a pre-compiled approximation. The trade-off is that the graph must be rebuilt each forward pass, which adds overhead that static frameworks avoid through upfront compilation. In practice, PyTorch's JIT compiler and torch.compile() (introduced in PyTorch 2.0) close most of that gap for production workloads.
This guide covers the full Autograd lifecycle: how the computational graph is built and destroyed, the most expensive mistakes in production training loops, custom gradient implementations for numerical stability, gradient checkpointing for large models, and the correct inference contexts for serving. By the end, you will be able to reason about why gradients are wrong before reaching for a debugger.
What Is Autograd and Backpropagation in PyTorch and Why Does It Exist?
Autograd exists because manual gradient derivation does not scale. A two-layer network with a sigmoid activation and cross-entropy loss already requires several lines of calculus to derive the gradients correctly. A 24-layer transformer with attention, layer norm, residual connections, and a mixture of activation functions would require weeks of derivation work — work that becomes incorrect the moment you change the architecture.
PyTorch's answer is automatic differentiation through a dynamic computational graph. Every operation you perform on a tensor with requires_grad=True gets recorded as a node in a directed acyclic graph (DAG). Each node stores references to its inputs and carries a gradient function — the local derivative of that specific operation. When you call .backward() on a scalar loss, PyTorch traverses this graph in reverse from the loss to every leaf tensor, multiplying local gradients together via the Chain Rule at each node. The result is the exact gradient of the loss with respect to every tracked parameter.
The Define-by-Run approach means the graph is built during execution, not beforehand. If your model has an if/else branch, PyTorch records whichever branch actually executed for that specific input. The gradient computation is then correct for that specific execution path. Static graph frameworks require you to express conditional logic using framework-specific operations that are compiled into a fixed graph — the gradient is correct for the compiled graph, which may differ from what you intended if the conditional logic is complex.
The practical trade-off: dynamic graph construction adds overhead on each forward pass that static frameworks avoid through upfront compilation. PyTorch 2.0's torch.compile() closes most of this gap for production workloads by compiling the dynamic graph into optimized static kernels while preserving the flexibility of Python-level control flow.
```python
# io.thecodeforge: Standard Autograd and Backpropagation Flow
# This example walks through every step of the gradient computation
# so you can see what Autograd is doing at each stage.
import torch

# requires_grad=True marks x as a leaf tensor whose gradient we want.
# Autograd will track every operation performed on x.
x = torch.ones(2, 2, requires_grad=True)
print(f"x.requires_grad: {x.requires_grad}")
print(f"x.grad_fn: {x.grad_fn}")  # None — x is a leaf, not computed from another tensor

# Forward pass — each operation creates a Function node in the graph
y = x + 2
print(f"y.grad_fn: {y.grad_fn}")  # AddBackward0 — knows how to compute its gradient
z = y * y * 3
print(f"z.grad_fn: {z.grad_fn}")  # MulBackward0
out = z.mean()
print(f"out.grad_fn: {out.grad_fn}")  # MeanBackward0
print(f"Forward Pass Result: {out.item()}")

# Backward pass — traverse the graph in reverse, applying Chain Rule at each node
# PyTorch computes d(out)/dx automatically
out.backward()

# x.grad now contains d(out)/dx
# Manual derivation: out = z.mean() = (3 * (x+2)^2).mean()
# d(out)/dx = 3 * 2 * (x+2) / 4 = 3 * 2 * 3 / 4 = 4.5 at x=1
print(f"Gradients d(out)/dx:\n{x.grad}")  # Expected: all 4.5 because x was initialized to ones

# The graph is destroyed after backward() — calling again raises RuntimeError
# This is intentional memory management, not a limitation
```
```
x.requires_grad: True
x.grad_fn: None
y.grad_fn: <AddBackward0 object at 0x...>
z.grad_fn: <MulBackward0 object at 0x...>
out.grad_fn: <MeanBackward0 object at 0x...>
Forward Pass Result: 27.0
Gradients d(out)/dx:
tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])
```
- Forward pass: execute operations and build the graph — each operation node stores its gradient function and holds references to its inputs
- Backward pass: start from the scalar loss and traverse the graph in reverse, calling each node's gradient function and multiplying by the upstream gradient via the Chain Rule
- The graph is rebuilt on every forward pass — it reflects the exact computation that just ran, including any Python branches that were taken
- requires_grad=True marks a tensor as a leaf node whose gradient we want accumulated in .grad after backward()
- A single .backward() call on the scalar loss computes gradients for every tracked parameter in the entire network — that is the efficiency that makes training practical
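The Define-by-Run point above is easiest to see with a tiny function containing ordinary Python control flow. A minimal sketch (the function name is illustrative):

```python
import torch

def piecewise(x):
    # Plain Python control flow: Autograd records only the branch that runs
    if x.sum() > 0:
        return (x * 2).sum()
    return (x ** 2).sum()

x = torch.ones(3, requires_grad=True)
piecewise(x).backward()
print(x.grad)  # tensor([2., 2., 2.]) — the gradient of the x * 2 branch that executed
```

The gradient is exact for the branch that actually ran; a different input could take the other branch and produce the gradient of x ** 2 instead.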
- Dynamic graph construction adds per-iteration overhead; torch.compile() in PyTorch 2.0+ recovers most of it for static sub-graphs within a dynamic model.
- For numerically unstable operations, implement a custom backward() that computes the gradient in a more stable form. Verify with torch.autograd.gradcheck() before deploying.
- Wrap serving code in torch.inference_mode() to disable graph construction entirely and save both memory and compute. Do not use no_grad() for production serving — inference_mode() is strictly faster.
- Use torch.autograd.gradcheck() to numerically verify computed gradients against finite differences. If gradcheck passes, Autograd is correct and the bug is elsewhere.

The Computational Graph — Anatomy and Lifecycle
The computational graph is not an abstract concept — it is a concrete data structure that PyTorch builds in memory during your forward pass and destroys after your backward pass. Understanding its lifecycle is the prerequisite for debugging GPU memory leaks, gradient errors, and unexpected RuntimeError messages.
Construction phase: every operation on a tensor with requires_grad=True creates a Function object. This object stores references to its input tensors and implements the backward computation for that specific operation. The Function objects chain together through their next_functions pointers to form the graph. This construction happens eagerly, operation by operation, as your code executes.
Retention phase: the graph exists from the end of the forward pass until backward() is called. During this window, the graph holds references to every intermediate tensor needed for gradient computation. This is where GPU memory is allocated for all those intermediate activations. The longer the graph, the more memory it holds during retention.
Destruction phase: after backward() completes, PyTorch releases the graph by default. All Function objects are freed, and the intermediate tensors they held references to can be garbage collected. This is intentional — PyTorch assumes you only need to call backward() once per forward pass.
The three production mistakes that come from misunderstanding this lifecycle: calling backward() twice without retain_graph=True (the graph is gone after the first call), storing loss tensors in lists without .item() (holding a reference to the tensor holds the entire graph), and using retain_graph=True in a training loop without releasing it (the graph grows with each iteration).
```python
# io.thecodeforge: Computational graph lifecycle and memory management
# Understanding this code is the key to diagnosing most PyTorch memory issues
import torch

# === GRAPH CONSTRUCTION ===
# Each operation creates a Function node that chains backward through next_functions
x = torch.randn(3, 3, requires_grad=True)
y = x * 2    # MulBackward0 node created, stores reference to x
z = y.sum()  # SumBackward0 node created, stores reference to y

print(f"z.grad_fn: {z.grad_fn}")  # SumBackward0
print(f"z.grad_fn.next_functions: {z.grad_fn.next_functions}")  # -> MulBackward0
print(f"MulBackward0.next_functions: {z.grad_fn.next_functions[0][0].next_functions}")  # -> AccumulateGrad (the leaf x)

# === RETENTION PHASE ===
# The graph is in memory between the forward pass completion and backward()
# Every intermediate tensor y is held alive by the graph during this window

# === DESTRUCTION — default behavior ===
z.backward()  # Graph traversed, gradients computed, graph freed
print(f"d(z)/dx:\n{x.grad}")  # tensor([[2., 2., 2.], ...])

# Attempting backward again raises RuntimeError — graph is gone
try:
    z.backward()
except RuntimeError as e:
    print(f"Expected error: {str(e)[:80]}...")

# === retain_graph=True — use sparingly ===
x.grad = None  # Reset gradient before next backward
y2 = x * 3
z2 = y2.sum()

# retain_graph=True keeps the graph alive after backward
# Use ONLY when you need multiple backward passes on the same graph
# (e.g., computing gradients of multiple scalar outputs separately)
z2.backward(retain_graph=True)  # Graph retained
first_grad = x.grad.clone()
print(f"After first backward (retain_graph=True):\n{first_grad}")

z2.backward()  # Second backward works — but gradients are ACCUMULATED
print(f"After second backward (accumulated):\n{x.grad}")
# x.grad is now 2x the correct gradient — accumulated, not replaced
# This is why retain_graph=True in a loop causes runaway gradient accumulation

# === MEMORY LEAK PATTERN — what NOT to do ===
# This pattern causes GPU memory to grow linearly with training steps:
# loss_history = []
# for batch in dataloader:
#     loss = model(batch)
#     loss_history.append(loss)  # WRONG: holds entire graph for every batch
#     loss.backward()
# Correct:
#     loss_history.append(loss.item())  # Python float, no graph reference
```
```
z.grad_fn: <SumBackward0 object at 0x...>
z.grad_fn.next_functions: ((<MulBackward0 object at 0x...>, 0),)
MulBackward0.next_functions: ((<AccumulateGrad object at 0x...>, 0),)
d(z)/dx:
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
Expected error: Trying to backward through the graph a second time (or directly access saved tensors aft...
After first backward (retain_graph=True):
tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])
After second backward (accumulated):
tensor([[6., 6., 6.],
        [6., 6., 6.],
        [6., 6., 6.]])
```
- Use retain_graph=True only when you genuinely need to call backward() multiple times on the same graph — for example, computing gradients of multiple scalar outputs from a multi-task model separately. Even then, call it only on all but the final backward() call so the graph is released after the last pass. Never set it as a default in a training loop to 'prevent errors' without understanding what it does.
- The graph is destroyed after backward() by default — this is correct and intentional. PyTorch engineers designed it this way because almost every training loop only calls backward() once per forward pass, and destroying the graph immediately is the safest memory management strategy.
- When gradients go missing, look for stray .detach() calls, no_grad(), or retain_graph=True left in from debugging.
- Think of the graph lifecycle in three phases: construction during the forward pass, retention until backward() completes, and destruction to free memory.
- The graph is freed after backward() by default — this is the correct behavior. retain_graph=True is a tool for specific multi-output scenarios, not a general-purpose option.

Common Mistakes and How to Avoid Them
Most Autograd bugs in production follow one of five patterns. They appear across teams, seniority levels, and model types — which means they are not careless mistakes but rather counterintuitive behaviors that are easy to get wrong even when you know the framework well.
1. Forgetting optimizer.zero_grad(). PyTorch accumulates gradients by default — .backward() adds to .grad rather than replacing it. In a standard training loop where you want gradients from only the current batch, failing to zero gradients causes every batch's gradients to stack. After 50 batches, the optimizer is responding to the sum of all 50 batches' gradients, which has no relation to the current batch's signal. The weight updates overshoot exponentially. This is the bug from the production incident section — it produced stable-looking training loss while silently corrupting the model.
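The accumulation behavior is easy to demonstrate on a bare tensor, without a model or optimizer. A minimal sketch:

```python
import torch

w = torch.ones(2, requires_grad=True)

(w * 3).sum().backward()
print(w.grad)  # tensor([3., 3.])

# A second backward on a fresh graph ADDS to .grad rather than replacing it
(w * 3).sum().backward()
print(w.grad)  # tensor([6., 6.]) — accumulated, not replaced

w.grad = None  # roughly what optimizer.zero_grad() does for each parameter
```

Each backward ran on its own freshly built graph, yet both wrote into the same .grad buffer. That additive write is the whole bug.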
2. In-place operations on tracked tensors. PyTorch's version counter system detects when a tensor is modified in-place after it was recorded in the graph. If the modified value is needed for a gradient computation, the version counter mismatch raises RuntimeError. Worse, some in-place operations in specific positions in the graph corrupt gradients silently rather than raising an error — the gradients are computed for the unmodified tensor value while the forward pass used the modified value.
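The raising case can be reproduced in four lines. exp() is a good trigger because its backward uses its own output, which the in-place edit then invalidates. A minimal sketch:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()   # ExpBackward saves its OUTPUT y for the backward pass
y.add_(1)     # in-place edit bumps y's version counter

try:
    y.sum().backward()
except RuntimeError as e:
    # Autograd detects that a saved tensor was modified in place
    print("caught:", str(e)[:60])
```

The silent-corruption variant is worse precisely because no such error fires, which is why the guide recommends gradcheck when in-place operations are suspected.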
3. Failing to detach tensors when logging. This is the most common GPU memory leak in training loops. Storing loss.item() is correct. Storing loss without .item() holds a reference to the entire computational graph — because loss is a tensor with a grad_fn that chains back through the entire forward pass. Every batch's graph accumulates in memory until OOM.
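The difference between the safe and unsafe logging values is visible directly on the tensors. A minimal sketch:

```python
import torch

x = torch.randn(4, requires_grad=True)
loss = (x ** 2).mean()

print(loss.grad_fn is not None)  # True — this tensor anchors the whole graph
print(type(loss.item()))         # <class 'float'> — safe to append to a history list

safe = loss.detach().cpu()       # for non-scalar tensors: drop the graph, move off GPU
print(safe.requires_grad, safe.grad_fn)  # False None
```

Anything with a non-None grad_fn keeps the graph alive for as long as you hold it; a float or a detached tensor does not.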
4. Accidentally wrapping training in torch.no_grad(). This is the silent learning failure. The model forward pass runs, the loss is computed, and nothing happens. No gradient is computed. No weight is updated. Loss stays constant or varies only from batch-to-batch data differences. It looks like a learning rate that is too low, or a model that has converged immediately. The tell is that gradient norms are all zero — add gradient norm logging and this becomes visible in the first epoch.
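A minimal reproduction of the "nothing learns" symptom, using a throwaway linear model:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

with torch.no_grad():  # the accidental wrapper
    loss = nn.functional.mse_loss(model(x), y)

print(loss.requires_grad)  # False — no graph, so backward() would raise
# and every parameter's .grad stays None, i.e. zero gradient norm
print(all(p.grad is None for p in model.parameters()))  # True
```

Calling loss.backward() here raises RuntimeError because the loss has no grad_fn, which is the loudest version of this failure; the quiet version is a loop that never calls backward at all and simply reports a flat loss.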
5. Calling backward() on a non-scalar without a gradient argument. backward() is designed for scalar outputs. If your loss is a vector (common when you forget to average across the batch), PyTorch raises RuntimeError immediately. The fix is either averaging the loss — loss.mean().backward() — or passing an explicit gradient tensor matching the loss shape.
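Both fixes for the non-scalar case fit in a short sketch:

```python
import torch

x = torch.randn(4, requires_grad=True)
loss_vec = x ** 2            # vector — backward() needs a scalar

try:
    loss_vec.backward()      # raises: grad can be implicitly created only for scalars
except RuntimeError:
    print("vector backward raised, as expected")

loss_vec.mean().backward()   # fix 1: reduce to a scalar first

x.grad = None
(x ** 2).backward(gradient=torch.ones_like(x))  # fix 2: explicit upstream gradient
print(x.grad.shape)  # torch.Size([4])
```

The explicit-gradient form computes a vector-Jacobian product; passing all ones is equivalent to calling backward on the sum, not the mean.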
```python
# io.thecodeforge: The correct training loop pattern — every decision annotated
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Dummy data
inputs = torch.randn(32, 10)  # batch of 32
targets = torch.randn(32, 1)

loss_history = []  # For monitoring — must store scalars, not tensors

for step in range(100):
    # STEP 1: Zero gradients BEFORE the forward pass.
    # Position matters: putting this after backward() instead of before
    # the next forward pass is equivalent for single-step loops but breaks
    # gradient accumulation workflows. Always put it before forward.
    optimizer.zero_grad()

    # STEP 2: Forward pass — builds the computational graph
    outputs = model(inputs)
    loss = criterion(outputs, targets)  # Scalar — averaged over batch by MSELoss

    # STEP 3: Log the scalar loss value WITHOUT a graph reference
    # loss is a tensor with requires_grad=True — storing it holds the entire graph
    # loss.item() returns a Python float — no graph reference, no memory leak
    loss_history.append(loss.item())  # Correct
    # loss_history.append(loss)       # WRONG — holds graph reference across steps

    # STEP 4: Backward pass — compute gradients for all parameters
    loss.backward()  # Graph is traversed and freed after this call

    # STEP 5: (Optional but recommended) Gradient clipping before optimizer step
    # Prevents a single bad batch from causing catastrophic weight updates
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # STEP 6: Update weights using computed gradients
    optimizer.step()

    # STEP 7: (Optional but valuable) Monitor gradient norms
    # A sudden spike here is the earliest signal of gradient issues
    if step % 10 == 0:
        grad_norm = sum(
            p.grad.norm().item() ** 2
            for p in model.parameters()
            if p.grad is not None
        ) ** 0.5
        print(f"Step {step:3d}: loss={loss_history[-1]:.4f}, grad_norm={grad_norm:.4f}")

print(f"\nFinal loss: {loss_history[-1]:.4f}")
print(f"All losses are Python floats (no graph refs): {type(loss_history[0])}")
```
```
Step  10: loss=0.9873, grad_norm=0.2914
Step  20: loss=0.8102, grad_norm=0.2341
Step  30: loss=0.6934, grad_norm=0.1923
Step  40: loss=0.5821, grad_norm=0.1612
Step  50: loss=0.4923, grad_norm=0.1334
Step  60: loss=0.4201, grad_norm=0.1102
Step  70: loss=0.3614, grad_norm=0.0923
Step  80: loss=0.3102, grad_norm=0.0781
Step  90: loss=0.2693, grad_norm=0.0664

Final loss: 0.2341
All losses are Python floats (no graph refs): <class 'float'>
```
Silent learning failures appear when torch.no_grad() or torch.inference_mode() accidentally wraps the training loop, when in-place operations silently corrupt gradients without raising an error, or when requires_grad is False on parameters that should be updated. The tell for all of these is gradient norms that are zero. Add gradient norm monitoring to every training loop — it is a one-line addition and it makes an entire class of silent failures immediately visible. A training run with zero gradient norms is not training. Full stop.

- Forgetting zero_grad() causes gradient accumulation that grows with every training step. The model appears to train for dozens of epochs before diverging — which is exactly long enough to waste serious compute budget before anyone investigates.
- Wrapping the training loop in torch.no_grad() means weights never update. This is the most common mistake made during debugging — someone adds no_grad() to test inference behavior, forgets to remove it, and the training run produces a model that is identical to its initialization.
- Always call optimizer.zero_grad() before the forward pass — PyTorch accumulates gradients by default, and the resulting weight updates compound in magnitude.
- Clip gradients with torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).
- If gradient norms are zero: check for no_grad() wrapping, in-place ops on leaf tensors, or .numpy() calls breaking the graph. If non-zero but collapsing: investigate the loss function for insufficient discriminative signal.
- Log scalars, not tensors (loss.item()). Replace tensor_list.append(tensor) with tensor_list.append(tensor.detach().cpu()). Check for retain_graph=True in any loop.

Custom Gradients with torch.autograd.Function
PyTorch's Autograd is correct for most operations, but correct and numerically stable are different properties. For operations involving logarithms near zero, exponentials of large values, or divisions by small numbers, the mathematically correct gradient can overflow float32 before it is computed. This is where torch.autograd.Function becomes essential.
The Function class exposes two static methods. forward() computes the output given the inputs and has access to ctx — a context object that can store tensors for use in backward() via ctx.save_for_backward(). backward() receives grad_output — the upstream gradient from the loss — and returns one gradient per input to forward(), or None for inputs that do not need gradients (like integer indices or dimension arguments).
The production use case that comes up most frequently: log-sum-exp and softmax operations. The standard computation exp(x) / sum(exp(x)) overflows in float32 for logits above approximately 88. The stable version subtracts the maximum before exponentiating: exp(x - max) / sum(exp(x - max)). The gradient of the stable version is the same mathematically but avoids the overflow in both forward and backward.
A second important use case: straight-through estimators for quantization. During the forward pass, you round continuous values to discrete bins. This operation has a gradient of zero everywhere (mathematically) which means no signal propagates back. The straight-through estimator uses a custom backward() that returns the upstream gradient unmodified — treating the quantization as an identity function for gradient purposes. This allows training of quantized models.
Before deploying any custom Function, always verify it with torch.autograd.gradcheck(). This function numerically estimates gradients using finite differences and compares them against your custom backward() output. A gradcheck pass does not guarantee your gradient is optimal, but it confirms it is correct.
```python
# io.thecodeforge: Custom autograd Function examples
# Two real production use cases: numerically stable log-sum-exp
# and straight-through estimator for quantization-aware training
import torch
from torch.autograd import Function

class StableLogSumExp(Function):
    """
    Numerically stable log-sum-exp with a custom backward pass.

    Why this exists: The naive computation log(sum(exp(x))) overflows float32
    when any element of x exceeds ~88. The stable version subtracts the maximum
    before exponentiating — mathematically identical but float32-safe.
    The gradient is softmax(x) * grad_output, computed in the stable form.
    """
    @staticmethod
    def forward(ctx, x, dim):
        ctx.dim = dim
        # Subtract max for numerical stability — does not change the result
        max_val = x.max(dim=dim, keepdim=True).values
        shifted = x - max_val
        exp_sum = shifted.exp().sum(dim=dim, keepdim=True)
        ctx.save_for_backward(x, max_val)
        return max_val.squeeze(dim) + exp_sum.log().squeeze(dim)

    @staticmethod
    def backward(ctx, grad_output):
        x, max_val = ctx.saved_tensors
        dim = ctx.dim
        # Stable softmax — same stability trick as forward
        shifted = (x - max_val).exp()
        softmax = shifted / shifted.sum(dim=dim, keepdim=True)
        # Chain rule: d(LSE)/dx_i = softmax_i
        # Multiply by upstream gradient and broadcast back to x shape
        return grad_output.unsqueeze(dim) * softmax, None  # None for dim argument

class StraightThroughEstimator(Function):
    """
    Straight-through estimator for quantization-aware training.

    Why this exists: Rounding (quantization) has zero gradient everywhere —
    no signal propagates backward through round() in standard Autograd.
    The straight-through estimator passes gradients straight through the
    quantization step as if it were an identity function. This allows
    gradient-based training of networks that use quantization at inference time.
    """
    @staticmethod
    def forward(ctx, x, num_bits=8):
        # Quantize to num_bits — this is the discrete rounding step
        scale = (2 ** num_bits - 1)
        return x.mul(scale).round().div(scale)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass gradient through unchanged — ignore the discontinuity
        # This is the straight-through estimator approximation
        return grad_output, None  # None for num_bits

# === Verification with gradcheck ===
# gradcheck numerically estimates gradients via finite differences
# and compares against your custom backward() — always run this before deploying
print("Verifying StableLogSumExp gradients...")
x_check = torch.randn(4, 8, requires_grad=True, dtype=torch.float64)  # float64 for gradcheck precision
result = torch.autograd.gradcheck(
    lambda x: StableLogSumExp.apply(x, 1),
    (x_check,), eps=1e-6, atol=1e-4
)
print(f"StableLogSumExp gradcheck: {'PASSED' if result else 'FAILED'}")

# === Production usage ===
# Note: Function.apply takes positional arguments only — no keywords
x = torch.randn(3, 10, requires_grad=True)
lse = StableLogSumExp.apply(x, 1)
loss = lse.sum()
loss.backward()
print(f"Custom gradient shape: {x.grad.shape}")
# Gradient sums to 1.0 per row because the gradient of LSE is softmax
print(f"Gradient row sums (should be ~1.0): {x.grad.sum(dim=1)}")

# === Quantization example ===
weights = torch.randn(4, 4, requires_grad=True)
quantized = StraightThroughEstimator.apply(weights, 8)
loss_q = quantized.sum()
loss_q.backward()
print(f"\nQuantization gradient (straight-through, same shape as weights): {weights.grad.shape}")
print(f"Gradient values (all 1.0 — passed straight through): {weights.grad.unique()}")
```
```
Verifying StableLogSumExp gradients...
StableLogSumExp gradcheck: PASSED
Custom gradient shape: torch.Size([3, 10])
Gradient row sums (should be ~1.0): tensor([1.0000, 1.0000, 1.0000])

Quantization gradient (straight-through, same shape as weights): torch.Size([4, 4])
Gradient values (all 1.0 — passed straight through): tensor([1.])
```
- Numerical stability: when the Chain Rule gradient overflows or underflows for your specific operation — log-sum-exp, softmax with large logits, division by values that can approach zero
- Non-differentiable operations: when you need to define a surrogate gradient for operations that are mathematically non-differentiable — quantization, thresholding, straight-through estimators for binary neural networks
- Performance: when Autograd's generic gradient is significantly slower than a hand-optimized kernel — fused operations that combine multiple steps can be faster than the generic backward graph
- Memory optimization: when you want to recompute intermediates in backward() rather than storing them in forward() — a form of manual gradient checkpointing for specific operations
- Always verify with torch.autograd.gradcheck() using float64 inputs — float32 is not precise enough for finite difference verification to be reliable
- A custom backward() that is mathematically wrong is significantly harder to debug than a failing test — the model will train, loss will decrease, and incorrect gradients will be invisible until you compare against a numerically estimated baseline.
- Use ctx.save_for_backward() for any tensors needed in backward() — storing them as instance attributes will cause reference counting issues.
- Verify every custom Function with torch.autograd.gradcheck() using float64 inputs before deploying. A wrong custom gradient trains silently and surfaces only in model quality metrics.

Memory Optimization — Gradient Checkpointing and Efficient Autograd
The computational graph stores intermediate activations from every operation in the forward pass — because backward() needs them to compute gradients. For a 24-layer transformer with a 512 sequence length and batch size 32, this can easily consume 30 to 40GB of GPU memory just for the graph's intermediate activations, independent of the model weights themselves.
Gradient checkpointing is the standard solution. Rather than storing all intermediates during the forward pass, torch.utils.checkpoint discards them and recomputes them during the backward pass from the nearest checkpoint boundary. The memory saved is proportional to how much of the graph is checkpointed. The compute cost is approximately 30% additional forward pass operations because each checkpointed segment runs twice — once during forward and once during backward reconstruction.
The 30% compute overhead sounds significant, but consider the alternative: if your model does not fit on one GPU without checkpointing, you either add GPUs (expensive and complex) or reduce batch size (hurts convergence). In most cases, 30% more compute on the same hardware is the correct trade-off.
The practical checkpointing strategy: checkpoint expensive blocks and skip cheap ones. Transformer attention blocks and feed-forward networks are the expensive ones — they dominate both memory and compute. Embedding lookups, layer norms, and linear projections at the input/output are cheap to store and expensive to recompute relative to the memory they consume. Checkpoint the transformer blocks, leave the rest.
PyTorch 2.0 introduced use_reentrant=False for torch.utils.checkpoint, which is strictly better than the old default. The old reentrant implementation has edge cases with double backward (gradient of gradients) and with custom autograd functions. Always use use_reentrant=False on PyTorch 2.0+.
```python
# io.thecodeforge: Gradient checkpointing for memory-efficient large model training
# This pattern reduces peak GPU memory for a 24-layer transformer
# from ~40GB to ~12GB at the cost of ~30% more compute
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TransformerBlock(nn.Module):
    """
    A standard transformer block: attention + FFN with residual connections.
    The expensive memory consumers: attention intermediate activations
    (sequence_len x sequence_len matrices) and FFN hidden states (4x expansion).
    """
    def __init__(self, d_model=512, nhead=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Without checkpointing: all of attn_out, ffn_out are stored for backward
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x


class CheckpointedTransformer(nn.Module):
    """
    Transformer with gradient checkpointing on expensive blocks.

    Memory profile comparison (12 layers, d_model=512, seq_len=512, batch=32):
    - Without checkpointing: ~24GB peak GPU memory
    - With checkpointing: ~8GB peak GPU memory
    - Compute overhead: ~30% more forward pass time
    """
    def __init__(self, n_layers=12, d_model=512):
        super().__init__()
        # Embedding is cheap to store — do not checkpoint it
        self.embedding = nn.Embedding(50000, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model) for _ in range(n_layers)
        ])
        # Output projection is also cheap — do not checkpoint
        self.output_proj = nn.Linear(d_model, 50000)

    def forward(self, x):
        # Embedding — cheap, store the intermediate
        x = self.embedding(x)
        for layer in self.layers:
            # checkpoint() discards all intermediates inside TransformerBlock.forward()
            # During backward, it re-executes the block to reconstruct them
            # use_reentrant=False: use the newer, safer implementation (PyTorch 2.0+)
            # The old default (use_reentrant=True) has edge cases with double backward
            x = checkpoint(layer, x, use_reentrant=False)
        # Output projection — cheap, store the intermediate
        return self.output_proj(x)


# Demonstrate memory measurement
model = CheckpointedTransformer(n_layers=12, d_model=512)
# Use integer token IDs for the embedding layer
batch = torch.randint(0, 50000, (4, 64))  # batch_size=4, seq_len=64 tokens

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

out = model(batch)
loss = out.sum()
loss.backward()

if torch.cuda.is_available():
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    print(f"Peak GPU memory: {peak_mb:.1f} MB")

print(f"Output shape: {out.shape}")
print(f"Embedding gradient exists: {model.embedding.weight.grad is not None}")
print(f"Layer 0 attention gradient exists: {model.layers[0].attn.in_proj_weight.grad is not None}")
print(f"All parameters received gradients: "
      f"{all(p.grad is not None for p in model.parameters())}")
```
Embedding gradient exists: True
Layer 0 attention gradient exists: True
All parameters received gradients: True
- Checkpoint transformer attention blocks and FFN layers — they dominate both memory and compute, making the recomputation trade-off favorable
- Do not checkpoint embedding layers or final projection layers — they are cheap to store and expensive to recompute relative to their memory footprint
- Always use use_reentrant=False on PyTorch 2.0+ — the old default has documented edge cases with double backward and custom autograd functions
- Expect approximately 30% more compute time for approximately 60% less peak memory — this trade-off is almost always worth it when the alternative is adding GPU hardware
- Do not checkpoint operations with side effects — print statements, logging calls, and assertion checks inside a checkpointed block will execute twice during backward
- Measure with torch.cuda.max_memory_allocated() before and after to verify actual savings — the theoretical reduction and the measured reduction can differ based on where you place checkpoint boundaries
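The double-execution pitfall in the list above is easy to demonstrate. The sketch below (the counter dict and block function are illustrative, not from the original article) shows a side effect inside a checkpointed function running once during forward and again during backward recomputation:

```python
import torch
from torch.utils.checkpoint import checkpoint

call_count = {"n": 0}

def block(x):
    # Side effect — executes during forward AND during backward recomputation.
    # A print, log call, or assertion here would fire twice per training step.
    call_count["n"] += 1
    return torch.relu(x) * 2

x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
forward_calls = call_count["n"]   # 1 — only the forward pass has run

y.sum().backward()
total_calls = call_count["n"]     # 2 — block was re-executed to rebuild intermediates
print(forward_calls, total_calls)
```

This is why the guidance says to keep checkpointed blocks free of side effects: anything observable inside them is observed twice.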
torch.no_grad(), .detach(), and Inference Mode
When you do not need gradients — during inference, validation, or metric computation — explicitly disabling Autograd is not optional. Leaving gradient tracking enabled during inference wastes memory (the entire computational graph is built for every forward pass) and compute (version counters are incremented for every operation). At production serving scale, this overhead compounds into meaningful latency and cost.
PyTorch provides three mechanisms with different performance characteristics and different safety guarantees.
torch.no_grad() disables gradient computation — the backward engine will not run for operations inside the context. However, it still increments version counters on tensor operations, which is the mechanism PyTorch uses to detect in-place modifications. This makes no_grad() safe to use during training-adjacent code like gradient clipping where you might mix tracked and untracked operations.
torch.inference_mode() (introduced in PyTorch 1.9) goes further. It disables both gradient computation and version counter tracking. Tensors created inside inference_mode() are permanently marked as non-version-tracked — you cannot call .backward() on them even after exiting the context. The performance gain over no_grad() is 10 to 20% for typical model inference, which compounds significantly at scale.
.detach() operates on a single tensor rather than a context. It creates a new tensor that shares the same storage (no copy) but is not connected to the computational graph. Use it when you need to pass a specific tensor to a non-PyTorch library, log a tensor value, or break a gradient flow path in the middle of a model. Unlike no_grad(), detach() works at tensor granularity rather than context granularity.
The production recommendation is clear: use inference_mode() for serving, no_grad() for validation loops during training, and .detach() for individual tensor operations that need to escape the graph.
```python
# io.thecodeforge: Comparing inference contexts — performance and safety
import torch
import torch.nn as nn
import time

model = nn.Sequential(
    nn.Linear(1024, 2048), nn.ReLU(),
    nn.Linear(2048, 2048), nn.ReLU(),
    nn.Linear(2048, 1024), nn.ReLU(),
    nn.Linear(1024, 10)
)
x = torch.randn(256, 1024)
N_ITER = 1000

# === Benchmark 1: Default (graph tracking enabled) ===
# Builds the full computational graph for every forward pass
# Memory: graph + all intermediates allocated
start = time.perf_counter()
for _ in range(N_ITER):
    _ = model(x)
default_time = time.perf_counter() - start

# === Benchmark 2: torch.no_grad() ===
# Disables gradient computation but still increments version counters
# Appropriate for validation during training (safe to mix with tracked ops)
start = time.perf_counter()
for _ in range(N_ITER):
    with torch.no_grad():
        _ = model(x)
no_grad_time = time.perf_counter() - start

# === Benchmark 3: torch.inference_mode() ===
# Disables gradient computation AND version tracking
# Maximum performance for pure inference — use for production serving
start = time.perf_counter()
for _ in range(N_ITER):
    with torch.inference_mode():
        _ = model(x)
inference_time = time.perf_counter() - start

print(f"Default (graph tracking): {default_time:.3f}s (baseline)")
print(f"no_grad(): {no_grad_time:.3f}s ({(1 - no_grad_time/default_time)*100:.1f}% faster)")
print(f"inference_mode(): {inference_time:.3f}s ({(1 - inference_time/default_time)*100:.1f}% faster)")

# === Correct usage patterns ===

# Pattern 1: Production inference — always use inference_mode
with torch.inference_mode():
    predictions = model(x)
# predictions.requires_grad is False — cannot call backward on them
# This is correct for serving: predictions are outputs, not training intermediates

# Pattern 2: Validation during training — no_grad is appropriate
# because you are still in a training context and may mix with gradient operations
loss = torch.tensor(0.5, requires_grad=True)  # example tracked tensor
with torch.no_grad():
    val_output = model(x)  # no gradient tracked for this
# loss.backward() still works here — no_grad only affected val_output

# Pattern 3: .detach() for logging a specific tensor
training_loss = torch.tensor(1.23, requires_grad=True)  # has grad_fn
loss_for_log = training_loss.detach()  # same value, no graph reference
loss_scalar = training_loss.item()     # Python float, no tensor at all
print(f"\nTraining loss requires_grad: {training_loss.requires_grad}")
print(f"Detached loss requires_grad: {loss_for_log.requires_grad}")
print(f".item() returns Python type: {type(loss_scalar)}")

# Pattern 4: .detach() to break gradient flow within a model
# Use when you want one branch to not affect gradients of shared layers
encoder_output = torch.randn(8, 128, requires_grad=True)
stop_gradient = encoder_output.detach()  # downstream ops see value but no gradient
print(f"\nStop-gradient tensor requires_grad: {stop_gradient.requires_grad}")
```
no_grad(): 1.823s (25.6% faster)
inference_mode(): 1.607s (34.4% faster)
Training loss requires_grad: True
Detached loss requires_grad: False
.item() returns Python type: <class 'float'>
Stop-gradient tensor requires_grad: False
- torch.no_grad(): disables gradient computation, keeps version tracking — use for validation during training where you mix tracked and untracked operations in the same scope
- torch.inference_mode(): disables gradients + version tracking — use for production serving and evaluation where no backward pass will ever be needed
- .detach(): breaks graph reference for one specific tensor — use for logging, visualization, or stop-gradient operations within a model
- .item(): extracts the scalar value as a Python float — use for logging scalar losses and metrics. No tensor, no graph, no ambiguity.
- Never use inference_mode() on tensors you plan to backpropagate through — they are permanently marked as non-tracked and backward() will fail
inference_mode() is typically 10 to 20% faster than no_grad() depending on model architecture and hardware. For a model serving 1000 requests per second, that is potentially hundreds of milliseconds of latency improvement per second of wall clock time — meaningful at production scale.

The version counter overhead that no_grad() does not eliminate is real and measurable. Every tensor operation inside no_grad() still increments an integer counter. For models with hundreds of operations per forward pass and thousands of requests per second, this overhead adds up.

Production serving uses inference_mode(). Validation during training uses no_grad(). Individual tensor operations that need to escape the graph use .detach(). These are not interchangeable — pick the right tool for each context.

Use torch.inference_mode() for production inference — not no_grad(). It disables both gradient computation and version tracking, providing 10 to 34% faster execution than default graph tracking for serving workloads. Use torch.no_grad() for validation during training when you need to mix tracked and untracked operations in adjacent code.

| Feature | Manual Differentiation | PyTorch Autograd |
|---|---|---|
| Graph Construction | Static — pre-derived equations fixed at development time | Dynamic (Define-by-Run) — graph built during each forward pass, reflecting exact execution |
| Complexity Handling | Exponentially difficult — each architectural change requires re-deriving all gradients | Automatic via Chain Rule — gradient functions update automatically when architecture changes |
| Conditional Logic Support | Requires special framework ops (cond, scan) — Python if/else cannot be used directly | Native Python if/else and loops work directly — graph records whichever branch executed |
| Code Maintenance | Must rewrite gradient math for every architectural change — error-prone and slow | Gradient math updates automatically — change the forward pass and backward follows |
| Numerical Stability | Can be hand-optimized per operation for maximum stability | Correct by default; override with torch.autograd.Function when stability is insufficient |
| Inference Performance | N/A — no graph overhead | Use inference_mode() to eliminate graph overhead entirely — 10-34% faster than default |
🎯 Key Takeaways
- Autograd automates the Chain Rule by building a dynamic computational graph during the forward pass and traversing it in reverse during backward(). Understanding the graph lifecycle — construction, retention, destruction — is the prerequisite for debugging memory leaks and gradient errors.
- Always call optimizer.zero_grad() before the forward pass. PyTorch accumulates gradients by default. Forgetting this produces unbounded gradient accumulation that silently corrupts models over dozens of epochs before the divergence becomes visible in loss metrics.
- Use .item() for scalar logging and .detach().cpu() for tensor logging. Every tensor stored without detaching holds a reference to its entire computational graph, causing linear GPU memory growth that terminates in OOM.
- Use torch.inference_mode() for production serving — it is 10 to 20% faster than no_grad() because it disables both gradient computation and version counter tracking. Reserve no_grad() for validation during training where tracked and untracked operations coexist.
- Gradient checkpointing trades approximately 30% more compute for approximately 60% less peak GPU memory by recomputing intermediates during backward instead of storing them during forward. Checkpoint expensive transformer blocks, skip cheap embedding layers.
- Implement torch.autograd.Function with a custom backward() when Autograd's default gradient is numerically unstable — log operations near zero, softmax with large logits, or division by small values. Always verify custom gradients with torch.autograd.gradcheck() using float64 inputs before deploying.
- Gradient clipping (torch.nn.utils.clip_grad_norm_) is not optional for production training. Treat max_norm=1.0 as a starting point. It prevents a single outlier batch from causing catastrophic weight updates that take multiple epochs to detect.
Interview Questions on This Topic
- Q: What is a computational graph and how does PyTorch's dynamic graph differ from static graph frameworks? (Mid-level)
- Q: Why do we need to call optimizer.zero_grad() in a PyTorch training loop? (Junior)
- Q: What does the requires_grad=True flag signify when creating a tensor, and when should you set it to False? (Junior)
- Q: Explain the role of the backward() function in the context of the Chain Rule. (Mid-level)
- Q: What is the difference between torch.no_grad() and torch.inference_mode(), and which should you use for production serving? (Senior)
Frequently Asked Questions
What is Autograd in PyTorch?
Autograd is PyTorch's automatic differentiation engine. It records every operation performed on tensors marked with requires_grad=True in a dynamic computational graph during the forward pass. When you call .backward() on a scalar output — typically the training loss — Autograd traverses this graph in reverse, applying the Chain Rule at each operation node to compute gradients with respect to every tracked tensor. The result is stored in each tensor's .grad attribute, ready for the optimizer to use for parameter updates. The key property: because the graph is built during execution rather than pre-compiled, it always reflects the exact computation that ran, including any conditional branches.
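The mechanics described above fit in a few lines. This minimal sketch (the polynomial is chosen purely for illustration) shows the record-then-traverse cycle:

```python
import torch

# requires_grad=True: record every operation on x in the dynamic graph
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2 * x   # y = x^2 + 2x, so dy/dx = 2x + 2

# Traverse the recorded graph in reverse, applying the Chain Rule
y.backward()

print(x.grad)  # tensor(8.) — 2*3 + 2, stored in the .grad attribute
```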
What is the difference between .backward() and torch.autograd.grad()?
.backward() computes gradients and stores them in the .grad attribute of each leaf tensor. It modifies tensor state in-place. This is the correct tool for training loops where you call optimizer.step() immediately after — the optimizer reads .grad attributes directly.
torch.autograd.grad() computes gradients and returns them as new tensors without modifying any .grad attributes. This is the correct tool when you need gradients as values to compute on — for example, computing Jacobians, Hessians (gradient of gradients), meta-learning algorithms that differentiate through the optimization process, or any case where you need the gradient tensor itself as an input to further computation rather than just for a weight update.
Why does PyTorch accumulate gradients instead of replacing them?
Gradient accumulation is a deliberate design decision that enables two important production patterns.
First, simulating large batch sizes with limited GPU memory: by accumulating gradients over N micro-batches before calling optimizer.step(), you can train with an effective batch size of N times your GPU's capacity. This is standard practice for large language model training where the effective batch size is thousands of samples but each GPU can only hold dozens.
Second, recurrent architectures with parameter sharing: RNNs use the same parameters at every time step, and gradients from all time steps must accumulate at those shared parameters before the weight update.
The trade-off is that standard training loops must explicitly call optimizer.zero_grad() to reset gradients between steps. PyTorch chose accumulation as the default because it makes the important patterns correct by default, at the cost of requiring an explicit reset for the common case.
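The micro-batch pattern from the first point above can be sketched as follows (the model, learning rate, and accum_steps values are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 4  # effective batch = 4 x micro-batch size

optimizer.zero_grad()
for _ in range(accum_steps):
    micro_batch = torch.randn(8, 10)
    loss = model(micro_batch).pow(2).mean()
    # Scale so the accumulated gradient matches one large-batch gradient;
    # backward() ADDS into .grad rather than overwriting it
    (loss / accum_steps).backward()

# All four micro-batches have contributed to .grad at this point
grad_norm = model.weight.grad.norm().item()
print(f"Accumulated gradient norm: {grad_norm:.4f}")

optimizer.step()       # one update from 4 accumulated micro-batches
optimizer.zero_grad()  # reset before the next accumulation cycle
```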
How do I verify my custom gradient implementation is correct?
Use torch.autograd.gradcheck() which numerically estimates gradients using finite differences and compares them against your custom backward() output. The critical detail: use float64 inputs, not float32. Float32 does not have sufficient precision for the finite difference comparison to be reliable at typical epsilon values — gradcheck will produce false failures or false passes.
Basic usage: result = torch.autograd.gradcheck(your_function, (input_float64,), eps=1e-6, atol=1e-4). A True return means your gradient matches numerical estimation within tolerance.
If your custom function will be used in second-order optimization — computing gradients of gradients, as in MAML or natural gradient methods — also run torch.autograd.gradgradcheck() to verify that the backward() itself is differentiable. A function that passes gradcheck but fails gradgradcheck has a non-differentiable backward pass, which breaks second-order algorithms silently.
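A runnable sketch of this workflow (the Square function is a toy example, not from the article): define a custom Function, then verify its hand-written backward against the finite-difference estimate with float64 inputs:

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # d(x^2)/dx = 2x

# float64 is required — float32 lacks the precision for reliable
# finite-difference comparison at typical eps values
x = torch.randn(5, dtype=torch.float64, requires_grad=True)

ok = torch.autograd.gradcheck(Square.apply, (x,), eps=1e-6, atol=1e-4)
print(ok)  # True — the analytic gradient matches the numerical estimate
```

If the backward were wrong (say, returning grad_output * x), gradcheck would raise with the mismatched Jacobian entries instead of returning True.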
What happens if I call .backward() twice on the same computation?
By default, the computational graph is destroyed after the first .backward() call. This happens because PyTorch frees the graph immediately after backward() completes to reclaim the GPU memory used by intermediate activations. A second .backward() call raises RuntimeError: Trying to backward through the graph a second time — the Function objects that the graph traversal would call have already been freed.
If you need to call backward() multiple times on the same graph — for example, computing gradients of multiple scalar outputs from a multi-task model separately, or debugging gradient flow — pass retain_graph=True to every backward() call except the final one: loss1.backward(retain_graph=True); loss2.backward(). The graph is retained in memory until the final call frees it.
Be aware that retain_graph=True causes gradients to accumulate — if you call backward() twice, .grad attributes contain the sum of both backward passes, not just the last one. Reset gradients with optimizer.zero_grad() or tensor.grad = None before each backward() if you need clean gradients per call.
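Both behaviors — the retained graph and the accumulation — can be seen in a minimal sketch (values chosen for illustration):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x  # dy/dx = 2x = 4 at x = 2

y.backward(retain_graph=True)  # graph survives this call
first = x.grad.item()          # 4.0

y.backward()                   # legal only because the graph was retained
second = x.grad.item()         # 8.0 — gradients accumulated: 4 + 4

print(first, second)
x.grad = None  # reset for a clean gradient on the next backward
```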
Developer and founder of TheCodeForge. I built this site because I was tired of tutorials that explain what to type without explaining why it works. Every article here is written to make concepts actually click.