
Autograd and Backpropagation in PyTorch

A comprehensive guide to Autograd and Backpropagation in PyTorch — master the engine behind neural network training, automatic differentiation, and gradient descent.
⚙️ Intermediate — basic ML / AI knowledge assumed
In this tutorial, you'll learn
  • Autograd automates the Chain Rule by building a dynamic computational graph during the forward pass and traversing it in reverse during backward(). Understanding the graph lifecycle — construction, retention, destruction — is the prerequisite for debugging memory leaks and gradient errors.
  • Always call optimizer.zero_grad() before the forward pass. PyTorch accumulates gradients by default. Forgetting this lets gradients compound across steps, silently corrupting models over dozens of epochs before the divergence becomes visible in loss metrics.
  • Use .item() for scalar logging and .detach().cpu() for tensor logging. Every tensor stored without detaching holds a reference to its entire computational graph, causing linear GPU memory growth that terminates in OOM.
✦ Plain-English analogy ✦ Real code with output ✦ Interview questions
Quick Answer
  • Autograd is PyTorch's automatic differentiation engine that records operations on tensors and computes gradients via backpropagation
  • Set requires_grad=True on a tensor to track all operations performed on it in a dynamic computational graph
  • Call .backward() to compute gradients — PyTorch applies the Chain Rule automatically through the recorded graph
  • Always call optimizer.zero_grad() before each backward pass — PyTorch accumulates gradients by default, which is useful for gradient accumulation but catastrophic if forgotten in standard training loops
  • Wrap inference code in torch.inference_mode() for maximum performance — it is 10-20% faster than torch.no_grad() because it skips version counter tracking
  • The dynamic graph (Define-by-Run) means Python if/else and loops inside your model just work — unlike static graph frameworks where conditional logic requires special ops
  • Gradient checkpointing trades ~30% more compute for ~60% less memory — the standard tool when a model does not fit on one GPU
🚨 START HERE
Autograd Debugging Cheat Sheet
Quick commands to diagnose gradient and training issues — start here before reading logs
🟡 Loss becomes NaN during training
Immediate Action: Enable anomaly detection to find the exact operation producing NaN gradients — it will slow training but give you a precise stack trace
Commands
torch.autograd.set_detect_anomaly(True)
print([(name, p.grad.norm().item()) for name, p in model.named_parameters() if p.grad is not None])
Fix Now: Add gradient clipping immediately: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0). This stops the symptom — then find and fix the root cause with anomaly detection output.
🟡 GPU memory keeps growing each epoch
Immediate Action: Find tensors retained in the computational graph — the culprit is almost always loss logging without .item() or tensors stored in lists without .detach()
Commands
torch.cuda.memory_summary(device=None, abbreviated=False)
print(loss.item()) # NOT print(loss) — .item() returns a Python float with no graph reference
Fix Now: Replace all loss.append(loss_tensor) with loss_log.append(loss_tensor.item()). Replace tensor_log.append(tensor) with tensor_log.append(tensor.detach().cpu()) to release GPU graph references.
🟡 Gradients are all zeros — model not learning
Immediate Action: Audit every parameter for requires_grad status and gradient values — a single disconnected operation upstream can zero out all downstream gradients
Commands
for n, p in model.named_parameters(): print(n, p.requires_grad, p.grad is not None, p.grad.norm().item() if p.grad is not None else 'NO GRAD')
torch.autograd.gradcheck(model, inputs, eps=1e-6, atol=1e-4) # numerically verify gradients — gradcheck expects double-precision inputs with requires_grad=True
Fix Now: Check for torch.no_grad() wrapping the training loop, in-place operations on leaf tensors, or stray .detach() calls — these break the graph silently with no error message. (Calling .numpy() directly on a tracked tensor at least raises a RuntimeError telling you to .detach() first.)
Production Incident: Training Silently Diverged After a Gradient Accumulation Bug
A production training job for a recommendation model reported steadily decreasing loss for 50 epochs, then suddenly produced NaN predictions in production. The model had been memorizing training data while gradients silently exploded.
Symptom: Model loss decreased steadily during training but inference produced NaN values for approximately 30% of production inputs. Validation loss diverged sharply after epoch 50 while training loss continued to decrease — a textbook sign of something fundamentally wrong, though the team initially attributed it to overfitting rather than a training loop bug.
Assumption: The learning rate was set too aggressively and was causing gradient explosion in the later stages of training. The team reduced the learning rate by half and restarted training — the problem reappeared at epoch 50 again.
Root cause: The training loop was missing optimizer.zero_grad(). PyTorch accumulates gradients by default — every backward() call adds to existing gradient values in the .grad attribute rather than replacing them. Over 50 epochs of accumulation, gradient magnitudes grew without bound. The optimizer was applying increasingly large weight updates in directions determined by the sum of 50 epochs of gradients rather than the current batch. Weights grew until intermediate activations overflowed to floating-point infinity, which propagated to NaN in subsequent operations. Training loss appeared healthy because the optimizer's massive corrections happened to reduce the accumulated loss signal, masking the divergence until inference ran on clean production data without the distortion.
Fix: Added optimizer.zero_grad() at the start of every training step — before the forward pass, not after. Added gradient clipping using torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) as a permanent safety net. Added NaN detection in the validation loop using torch.isnan(val_loss).any() to catch divergence within the same epoch it occurs. Added gradient norm logging to the monitoring dashboard — a spike in gradient norms is the earliest signal of accumulation bugs or learning rate issues, appearing several epochs before loss divergence becomes visible.
Key Lesson
  • PyTorch accumulates gradients by default — always call optimizer.zero_grad() before the forward pass, not after the backward pass, because the position relative to the forward matters for gradient accumulation workflows
  • Monitor gradient norms during training — the norm should be stable within an order of magnitude. Sudden spikes indicate accumulation bugs, learning rate issues, or outlier batches before the loss shows any sign of divergence
  • Add NaN detection in validation loops — training loss can appear healthy while the model silently diverges because accumulated gradients produce weight corrections that happen to reduce the loss signal even as weights grow unbounded
  • Gradient clipping is not optional for production training — treat max_norm=1.0 as a default starting point and tune from there. It prevents a single outlier batch from causing catastrophic weight updates that take multiple epochs to detect
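The safeguards from the fix can be sketched in a few lines. This is a minimal illustration, not the incident's actual code — the model and data are placeholders — and it leans on the fact that clip_grad_norm_ returns the pre-clip total norm, so one call doubles as safety net and monitoring signal:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)          # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

opt.zero_grad()                   # before the forward pass
loss = nn.functional.mse_loss(model(x), y)

# NaN guard -- catches divergence in the same step it occurs
if torch.isnan(loss).any():
    raise RuntimeError("loss is NaN, stopping before weights are corrupted")

loss.backward()

# clip_grad_norm_ clips in place AND returns the pre-clip total norm,
# which is exactly the signal to log on a monitoring dashboard
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"grad norm before clipping: {total_norm.item():.4f}")

opt.step()
```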
Production Debug Guide
Common symptoms when gradient computation goes wrong — and where to look first
Loss is NaN after a few training steps: Add torch.autograd.set_detect_anomaly(True) at the start of the training script. It will identify the exact operation that produced the NaN gradient with a full stack trace. Common causes: division by zero in a custom loss, log of a negative number, or accumulated gradients that overflow float32 range. After identifying the operation, add gradient clipping as an immediate fix and investigate the root cause separately.
Model trains but produces identical outputs for all inputs (collapsed representations): Check gradient magnitudes for every parameter: print([(name, p.grad.norm().item()) for name, p in model.named_parameters() if p.grad is not None]). If gradients are zero or near-zero, check for dead ReLU neurons caused by large negative biases, weights initialized to zero, or in-place operations that broke the graph. If gradients are non-zero but the model still collapses, suspect a loss function that does not create sufficient gradient signal to discriminate between inputs.
GPU memory grows without bound across training steps: The most common cause is storing loss tensors in a Python list without calling .item() or .detach(). Each loss tensor holds a reference to the entire computational graph that produced it. Check for patterns like losses.append(loss) instead of losses.append(loss.item()). Run torch.cuda.memory_summary() to see which allocation is growing. Also check for retain_graph=True in a loop — it keeps every graph alive.
Gradients are zero but loss is non-zero — model is not learning: Check for four root causes in order: (1) torch.no_grad() or torch.inference_mode() wrapping the training loop accidentally, (2) in-place operations on leaf tensors that break the graph silently, (3) a round-trip through NumPy via .detach().numpy(), which severs the graph (plain .numpy() on a tracked tensor raises a RuntimeError rather than failing silently), (4) a disconnected computation path where the loss does not actually depend on the parameters. Use model.named_parameters() to confirm requires_grad is True for all trainable parameters.
Backward pass is significantly slower than forward pass: Profile with torch.profiler.profile() to identify which backward operation is the bottleneck. Common causes: naive attention implementations that materialize the full N×N attention matrix in memory, operations that create large intermediate tensors in backward but not forward, or missing gradient checkpointing on deep models where backward must traverse a very long graph. For transformers, verify you are using memory-efficient attention kernels (Flash Attention or scaled_dot_product_attention in PyTorch 2.0+).
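Gradient checkpointing, mentioned above as the memory-for-compute trade, can be sketched with torch.utils.checkpoint. The block count and layer sizes here are arbitrary; use_reentrant=False is the non-deprecated variant in recent PyTorch releases:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A deep stack whose intermediate activations would normally all stay
# alive from the end of the forward pass until backward()
blocks = nn.ModuleList(
    nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(8)
)

x = torch.randn(16, 64, requires_grad=True)
h = x
for block in blocks:
    # checkpoint() drops this block's activations after the forward pass
    # and recomputes them during backward: less memory, more compute
    h = checkpoint(block, h, use_reentrant=False)

h.sum().backward()
print(x.grad.shape)  # gradients flow through checkpointed segments normally
```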

Autograd, the machinery behind backpropagation in PyTorch, is the engine that makes neural network training work. Understanding it at a mechanical level, not just as a black box, is what separates engineers who can debug training failures from engineers who restart the training run and hope for different results.

PyTorch uses a Define-by-Run approach where the computational graph is built dynamically during the forward pass. This means the graph structure changes with every iteration if your model contains conditional logic or variable-length sequences — something static graph frameworks handle awkwardly through specialized conditional operations that bear little resemblance to ordinary Python code.

The practical consequence: gradients are always correct for the exact computation that just ran, not a pre-compiled approximation. The trade-off is that the graph must be rebuilt each forward pass, which adds overhead that static frameworks avoid through upfront compilation. In practice, PyTorch's JIT compiler and torch.compile() (introduced in PyTorch 2.0) close most of that gap for production workloads.

This guide covers the full Autograd lifecycle: how the computational graph is built and destroyed, the most expensive mistakes in production training loops, custom gradient implementations for numerical stability, gradient checkpointing for large models, and the correct inference contexts for serving. By the end, you will be able to reason about why gradients are wrong before reaching for a debugger.

What Is Autograd and Backpropagation in PyTorch and Why Does It Exist?

Autograd exists because manual gradient derivation does not scale. A two-layer network with a sigmoid activation and cross-entropy loss already requires several lines of calculus to derive the gradients correctly. A 24-layer transformer with attention, layer norm, residual connections, and a mixture of activation functions would require weeks of derivation work — work that becomes incorrect the moment you change the architecture.

PyTorch's answer is automatic differentiation through a dynamic computational graph. Every operation you perform on a tensor with requires_grad=True gets recorded as a node in a directed acyclic graph (DAG). Each node stores references to its inputs and carries a gradient function — the local derivative of that specific operation. When you call .backward() on a scalar loss, PyTorch traverses this graph in reverse from the loss to every leaf tensor, multiplying local gradients together via the Chain Rule at each node. The result is the exact gradient of the loss with respect to every tracked parameter.

The Define-by-Run approach means the graph is built during execution, not beforehand. If your model has an if/else branch, PyTorch records whichever branch actually executed for that specific input. The gradient computation is then correct for that specific execution path. Static graph frameworks require you to express conditional logic using framework-specific operations that are compiled into a fixed graph — the gradient is correct for the compiled graph, which may differ from what you intended if the conditional logic is complex.
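A minimal sketch of that behavior: the recorded graph, and therefore the gradient, reflects whichever Python branch actually ran. The function name here is illustrative.

```python
import torch

def scale(x):
    # Ordinary Python control flow: Autograd records only the branch
    # that actually executes for this particular input
    if x.sum() > 0:
        return (x * 2).sum()
    return (x * 3).sum()

a = torch.ones(3, requires_grad=True)            # positive sum: first branch
scale(a).backward()
print(a.grad)   # tensor([2., 2., 2.])

b = torch.full((3,), -1.0, requires_grad=True)   # negative sum: second branch
scale(b).backward()
print(b.grad)   # tensor([3., 3., 3.])
```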

The practical trade-off: dynamic graph construction adds overhead on each forward pass that static frameworks avoid through upfront compilation. PyTorch 2.0's torch.compile() closes most of this gap for production workloads by compiling the dynamic graph into optimized static kernels while preserving the flexibility of Python-level control flow.

io.thecodeforge.ml.autograd_example.py · PYTHON
# io.thecodeforge: Standard Autograd and Backpropagation Flow
# This example walks through every step of the gradient computation
# so you can see what Autograd is doing at each stage.
import torch

# requires_grad=True marks x as a leaf tensor whose gradient we want.
# Autograd will track every operation performed on x.
x = torch.ones(2, 2, requires_grad=True)
print(f"x.requires_grad: {x.requires_grad}")
print(f"x.grad_fn: {x.grad_fn}")  # None — x is a leaf, not computed from another tensor

# Forward pass — each operation creates a Function node in the graph
y = x + 2
print(f"y.grad_fn: {y.grad_fn}")  # AddBackward0 — knows how to compute its gradient

z = y * y * 3
print(f"z.grad_fn: {z.grad_fn}")  # MulBackward0

out = z.mean()
print(f"out.grad_fn: {out.grad_fn}")  # MeanBackward0
print(f"Forward Pass Result: {out.item()}")

# Backward pass — traverse the graph in reverse, applying Chain Rule at each node
# PyTorch computes d(out)/dx automatically
out.backward()

# x.grad now contains d(out)/dx
# Manual derivation: out = z.mean() = (3 * (x+2)^2).mean()
# d(out)/dx = 3 * 2 * (x+2) / 4 = 3 * 2 * 3 / 4 = 4.5 at x=1
print(f"Gradients d(out)/dx:\n{x.grad}")
# Expected: all 4.5 because x was initialized to ones

# The graph is destroyed after backward() — calling again raises RuntimeError
# This is intentional memory management, not a limitation
▶ Output
x.requires_grad: True
x.grad_fn: None
y.grad_fn: <AddBackward0 object at 0x...>
z.grad_fn: <MulBackward0 object at 0x...>
out.grad_fn: <MeanBackward0 object at 0x...>
Forward Pass Result: 27.0
Gradients d(out)/dx:
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
Mental Model
How Autograd Thinks About Gradients
Autograd treats every tensor operation as a node in a graph, records how to reverse it during the forward pass, and then runs all those reversals in sequence during the backward pass. You write the forward pass once — Autograd handles the reverse.
  • Forward pass: execute operations and build the graph — each operation node stores its gradient function and holds references to its inputs
  • Backward pass: start from the scalar loss and traverse the graph in reverse, calling each node's gradient function and multiplying by the upstream gradient via the Chain Rule
  • The graph is rebuilt on every forward pass — it reflects the exact computation that just ran, including any Python branches that were taken
  • requires_grad=True marks a tensor as a leaf node whose gradient we want accumulated in .grad after backward()
  • A single .backward() call on the scalar loss computes gradients for every tracked parameter in the entire network — that is the efficiency that makes training practical
📊 Production Insight
The dynamic graph means gradients are always correct for the specific computation that ran on that specific batch. A static graph framework computes gradients for the compiled graph — if your conditional logic does not compile cleanly into the graph representation, you get incorrect gradients with no error.
For models with data-dependent control flow — variable-length sequences, adaptive computation, mixture-of-experts routing — PyTorch's dynamic graph is the correct tool. The overhead of graph construction is real but manageable, and torch.compile() in PyTorch 2.0+ recovers most of it for static sub-graphs within a dynamic model.
Rule: if your model has conditional logic or variable-length inputs, PyTorch's dynamic graph is the safer and more correct choice. The flexibility is not a concession — it is the primary design goal.
🎯 Key Takeaway
Autograd automates the Chain Rule by recording operations in a dynamic computational graph during the forward pass and replaying them in reverse during the backward pass.
The graph is rebuilt every forward pass — gradients always match the exact computation that ran, including any conditional branches. This correctness guarantee is what makes PyTorch reliable for research and production alike.
This is why models with conditional logic, variable-length sequences, and dynamic routing work naturally in PyTorch without special-casing.
When to Use Autograd vs. Manual Gradients
If: Standard neural network layers — Linear, Conv2d, MultiheadAttention, LayerNorm
Use: Autograd unconditionally. The built-in gradient functions are numerically optimized, tested against finite differences, and maintained by the PyTorch team. There is no reason to override them.
If: Custom operation where Autograd's gradient is numerically unstable — large logits, log operations near zero, division by small values
Use: Implement torch.autograd.Function with a custom backward() that computes the gradient in a more stable form. Verify with torch.autograd.gradcheck() before deploying.
If: Inference only — production serving, validation, metric computation
Use: Wrap in torch.inference_mode() to disable graph construction entirely and save both memory and compute. Do not use no_grad() for production serving — inference_mode() is strictly faster.
If: Debugging suspected gradient errors — custom loss, custom layer, unexpected NaN
Use: torch.autograd.gradcheck() to numerically verify computed gradients against finite differences. If gradcheck passes, Autograd is correct and the bug is elsewhere.
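The inference rows above are easy to verify directly. A small sketch contrasting the two contexts (the model shapes are arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(8, 4)

# inference_mode() skips graph construction AND version-counter tracking;
# its outputs are flagged as inference tensors and can never re-enter autograd
with torch.inference_mode():
    y = model(x)
print(y.requires_grad, y.is_inference())   # False True

# no_grad() outputs are ordinary tensors that may feed later autograd work
with torch.no_grad():
    z = model(x)
print(z.requires_grad, z.is_inference())   # False False
```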

The Computational Graph — Anatomy and Lifecycle

The computational graph is not an abstract concept — it is a concrete data structure that PyTorch builds in memory during your forward pass and destroys after your backward pass. Understanding its lifecycle is the prerequisite for debugging GPU memory leaks, gradient errors, and unexpected RuntimeError messages.

Construction phase: every operation on a tensor with requires_grad=True creates a Function object. This object stores references to its input tensors and implements the backward computation for that specific operation. The Function objects chain together through their next_functions pointers to form the graph. This construction happens eagerly, operation by operation, as your code executes.

Retention phase: the graph exists from the end of the forward pass until backward() is called. During this window, the graph holds references to every intermediate tensor needed for gradient computation. This is where GPU memory is allocated for all those intermediate activations. The longer the graph, the more memory it holds during retention.

Destruction phase: after backward() completes, PyTorch releases the graph by default. All Function objects are freed, and the intermediate tensors they held references to can be garbage collected. This is intentional — PyTorch assumes you only need to call backward() once per forward pass.

The three production mistakes that come from misunderstanding this lifecycle: calling backward() twice without retain_graph=True (the graph is gone after the first call), storing loss tensors in lists without .item() (holding a reference to the tensor holds the entire graph), and using retain_graph=True in a training loop without releasing it (the graph grows with each iteration).

io.thecodeforge.ml.graph_lifecycle.py · PYTHON
# io.thecodeforge: Computational graph lifecycle and memory management
# Understanding this code is the key to diagnosing most PyTorch memory issues
import torch

# === GRAPH CONSTRUCTION ===
# Each operation creates a Function node that chains backward through next_functions
x = torch.randn(3, 3, requires_grad=True)
y = x * 2           # MulBackward0 node created, stores reference to x
z = y.sum()         # SumBackward0 node created, stores reference to y

print(f"z.grad_fn: {z.grad_fn}")  # SumBackward0
print(f"z.grad_fn.next_functions: {z.grad_fn.next_functions}")  # -> MulBackward0
print(f"MulBackward0.next_functions: {z.grad_fn.next_functions[0][0].next_functions}")
# -> AccumulateGrad (the leaf x)

# === RETENTION PHASE ===
# The graph is in memory between the forward pass completion and backward()
# Every intermediate tensor y is held alive by the graph during this window

# === DESTRUCTION — default behavior ===
z.backward()  # Graph traversed, gradients computed, graph freed
print(f"d(z)/dx:\n{x.grad}")  # tensor([[2., 2., 2.], ...])

# Attempting backward again raises RuntimeError — graph is gone
try:
    z.backward()
except RuntimeError as e:
    print(f"Expected error: {str(e)[:80]}...")

# === retain_graph=True — use sparingly ===
x.grad = None  # Reset gradient before next backward
y2 = x * 3
z2 = y2.sum()

# retain_graph=True keeps the graph alive after backward
# Use ONLY when you need multiple backward passes on the same graph
# (e.g., computing gradients of multiple scalar outputs separately)
z2.backward(retain_graph=True)  # Graph retained
first_grad = x.grad.clone()
print(f"After first backward (retain_graph=True):\n{first_grad}")

z2.backward()  # Second backward works — but gradients are ACCUMULATED
print(f"After second backward (accumulated):\n{x.grad}")
# x.grad is now 2x the correct gradient — accumulated, not replaced
# This is why retain_graph=True left in a loop silently inflates both gradients and memory

# === MEMORY LEAK PATTERN — what NOT to do ===
# This pattern causes GPU memory to grow linearly with training steps:
# loss_history = []
# for batch in dataloader:
#     loss = model(batch)
#     loss_history.append(loss)  # WRONG: holds entire graph for every batch
#     loss.backward()
#
# Correct:
# loss_history.append(loss.item())  # Python float, no graph reference
▶ Output
z.grad_fn: <SumBackward0 object at 0x...>
z.grad_fn.next_functions: ((<MulBackward0 object at 0x...>, 0),)
MulBackward0.next_functions: ((<AccumulateGrad object at 0x...>, 0),)
d(z)/dx:
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
Expected error: Trying to backward through the graph a second time (or directly access saved tensors aft...
After first backward (retain_graph=True):
tensor([[3., 3., 3.],
[3., 3., 3.],
[3., 3., 3.]])
After second backward (accumulated):
tensor([[6., 6., 6.],
[6., 6., 6.],
[6., 6., 6.]])
⚠ retain_graph=True Is a Memory Leak Waiting to Happen
Every call to retain_graph=True keeps the entire computational graph in memory until you explicitly release it or the tensor falls out of scope. In a training loop, this means the graph for every forward pass accumulates in GPU memory — linear growth with each iteration. The symptom is a CUDA out-of-memory error that appears gradually rather than immediately, usually after 50 to 200 iterations depending on model size. Use retain_graph=True only when you have a specific reason to call backward() multiple times on the same graph — for example, computing gradients of multiple scalar outputs from a multi-task model separately. Even then, call it only on all but the final backward() call so the graph is released after the last pass. Never set it as a default in a training loop to 'prevent errors' without understanding what it does.
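For reference, the legitimate multi-backward pattern looks like this: retain the graph on every backward() except the last, so it is freed after the final pass. The two losses here are toy stand-ins for a multi-task model's outputs:

```python
import torch

x = torch.randn(4, requires_grad=True)
h = x * 2                      # sub-graph shared by both losses
loss_a = h.sum()
loss_b = (h ** 2).sum()

loss_a.backward(retain_graph=True)   # keep the shared graph alive
loss_b.backward()                    # final pass: graph freed afterwards

# Gradients of BOTH losses accumulate in x.grad: 2 + 8*x
print(x.grad)
```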
📊 Production Insight
The graph is destroyed after backward() by default — this is correct and intentional. PyTorch engineers designed it this way because almost every training loop only calls backward() once per forward pass, and destroying the graph immediately is the safest memory management strategy.
The most common GPU OOM error in long training runs is not from the model being too large — it is from retained graph references in loss logging lists, validation loops that forgot no_grad(), or retain_graph=True left in from debugging.
Rule: if your GPU memory is growing between epochs with a fixed batch size, you have a graph retention problem. Start by auditing every place you store a loss or activation tensor — every one of them that is stored without .item() or .detach() is holding a graph reference.
🎯 Key Takeaway
The computational graph has three phases: construction during the forward pass, retention until backward() completes, and destruction to free memory.
The graph is destroyed after backward() by default — this is the correct behavior. retain_graph=True is a tool for specific multi-output scenarios, not a general-purpose option.
Memory leaks from retained graph references are the leading cause of GPU OOM errors in production training runs. Always use .item() for scalar logging and .detach() for tensor logging.

Common Mistakes and How to Avoid Them

Most Autograd bugs in production follow one of five patterns. They appear across teams, seniority levels, and model types — which means they are not careless mistakes but rather counterintuitive behaviors that are easy to get wrong even when you know the framework well.

1. Forgetting optimizer.zero_grad(). PyTorch accumulates gradients by default — .backward() adds to .grad rather than replacing it. In a standard training loop where you want gradients from only the current batch, failing to zero gradients causes every batch's gradients to stack. After 50 batches, the optimizer is responding to the sum of all 50 batches' gradients, which has no relation to the current batch's signal. The weight updates overshoot exponentially. This is the bug from the production incident section — it produced stable-looking training loss while silently corrupting the model.

2. In-place operations on tracked tensors. PyTorch's version counter system detects when a tensor is modified in-place after it was recorded in the graph. If the modified value is needed for a gradient computation, the version counter mismatch raises RuntimeError. Worse, some in-place operations in specific positions in the graph corrupt gradients silently rather than raising an error — the gradients are computed for the unmodified tensor value while the forward pass used the modified value.
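The detectable case is easy to reproduce. torch.exp is a convenient trigger because its backward reuses its own output, so an in-place edit to that output trips the version counter:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.exp(x)    # ExpBackward0 saves its OUTPUT y for the backward pass
y.add_(1.0)         # in-place edit bumps y's version counter

caught = False
try:
    y.sum().backward()
except RuntimeError:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation": the saved tensor's version changed
    caught = True
print(f"version-counter check fired: {caught}")
```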

3. Failing to detach tensors when logging. This is the most common GPU memory leak in training loops. Storing loss.item() is correct. Storing loss without .item() holds a reference to the entire computational graph — because loss is a tensor with a grad_fn that chains back through the entire forward pass. Every batch's graph accumulates in memory until OOM.

4. Accidentally wrapping training in torch.no_grad(). This is the silent learning failure. The model forward pass runs, the loss is computed, and nothing happens. No gradient is computed. No weight is updated. Loss stays constant or varies only from batch-to-batch data differences. It looks like a learning rate that is too low, or a model that has converged immediately. The tell is that gradient norms are all zero — add gradient norm logging and this becomes visible in the first epoch.

5. Calling backward() on a non-scalar without a gradient argument. backward() is designed for scalar outputs. If your loss is a vector (common when you forget to average across the batch), PyTorch raises RuntimeError immediately. The fix is either averaging the loss — loss.mean().backward() — or passing an explicit gradient tensor matching the loss shape.
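Both fixes can be shown side by side; the tensor here stands in for a per-example loss vector:

```python
import torch

x = torch.randn(4, requires_grad=True)
per_example = x ** 2               # vector loss: one value per example

# Fix 1: reduce to a scalar before backward() -- the usual pattern
per_example.mean().backward()
print(x.grad)                      # x / 2, i.e. 2*x averaged over 4 elements

# Fix 2: pass an explicit upstream gradient matching the loss shape;
# ones_like() makes this equivalent to sum().backward()
x.grad = None
(x ** 2).backward(gradient=torch.ones_like(x))
print(x.grad)                      # 2*x
```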

io.thecodeforge.ml.zero_grad_mistake.py · PYTHON
# io.thecodeforge: The correct training loop pattern — every decision annotated
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Dummy data
inputs = torch.randn(32, 10)   # batch of 32
targets = torch.randn(32, 1)

loss_history = []  # For monitoring — must store scalars, not tensors

for step in range(100):
    # STEP 1: Zero gradients BEFORE the forward pass.
    # Position matters: putting this after backward() instead of before
    # the next forward pass is equivalent for single-step loops but breaks
    # gradient accumulation workflows. Always put it before forward.
    optimizer.zero_grad()

    # STEP 2: Forward pass — builds the computational graph
    outputs = model(inputs)
    loss = criterion(outputs, targets)  # Scalar — averaged over batch by MSELoss

    # STEP 3: Log the scalar loss value WITHOUT a graph reference
    # loss is a tensor with requires_grad=True — storing it holds the entire graph
    # loss.item() returns a Python float — no graph reference, no memory leak
    loss_history.append(loss.item())  # Correct
    # loss_history.append(loss)       # WRONG — holds graph reference across steps

    # STEP 4: Backward pass — compute gradients for all parameters
    loss.backward()  # Graph is traversed and freed after this call

    # STEP 5: (Optional but recommended) Gradient clipping before optimizer step
    # Prevents a single bad batch from causing catastrophic weight updates
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # STEP 6: Update weights using computed gradients
    optimizer.step()

    # STEP 7: (Optional but valuable) Monitor gradient norms
    # A sudden spike here is the earliest signal of gradient issues
    if step % 10 == 0:
        grad_norm = sum(
            p.grad.norm().item() ** 2
            for p in model.parameters() if p.grad is not None
        ) ** 0.5
        print(f"Step {step:3d}: loss={loss_history[-1]:.4f}, grad_norm={grad_norm:.4f}")

print(f"\nFinal loss: {loss_history[-1]:.4f}")
print(f"All losses are Python floats (no graph refs): {type(loss_history[0])}")
▶ Output
Step 0: loss=1.2341, grad_norm=0.3821
Step 10: loss=0.9873, grad_norm=0.2914
Step 20: loss=0.8102, grad_norm=0.2341
Step 30: loss=0.6934, grad_norm=0.1923
Step 40: loss=0.5821, grad_norm=0.1612
Step 50: loss=0.4923, grad_norm=0.1334
Step 60: loss=0.4201, grad_norm=0.1102
Step 70: loss=0.3614, grad_norm=0.0923
Step 80: loss=0.3102, grad_norm=0.0781
Step 90: loss=0.2693, grad_norm=0.0664

Final loss: 0.2341
All losses are Python floats (no graph refs): <class 'float'>
⚠ Watch Out — The Silent Training Failure
The most dangerous mistake with Autograd is not a crash — it is a silent failure where the model appears to train but learns nothing. This happens when torch.no_grad() or torch.inference_mode() accidentally wraps the training loop, when in-place operations silently corrupt gradients without raising an error, or when requires_grad is False on parameters that should be updated. The tell for all of these is gradient norms that are zero. Add gradient norm monitoring to every training loop — it is a one-line addition and it makes an entire class of silent failures immediately visible. A training run with zero gradient norms is not training. Full stop.
📊 Production Insight
Forgetting zero_grad() causes gradient accumulation that grows without bound — every backward pass sums into the same .grad buffers. The model appears to train for dozens of epochs before diverging — which is exactly long enough to waste serious compute budget before anyone investigates.
In-place operations on tracked tensors either raise RuntimeError (detectable) or silently corrupt gradients (not detectable without gradcheck). The silent corruption case is worse — the model trains with wrong gradients and you only discover it when production metrics degrade.
Wrapping training in torch.no_grad() means weights never update. This is the most common mistake made during debugging — someone adds no_grad() to test inference behavior, forgets to remove it, and the training run produces a model that is identical to its initialization.
🎯 Key Takeaway
Always call optimizer.zero_grad() before the forward pass — PyTorch accumulates gradients by default, summing each backward pass into .grad so that magnitudes grow without bound.
Use .item() for scalar logging and .detach().cpu() for tensor logging. Every tensor stored without detaching holds a reference to its entire computational graph.
In-place operations on tracked tensors either raise RuntimeError or corrupt gradients silently — replace them with out-of-place alternatives unconditionally.
Debugging Gradient Issues
If: Loss is NaN or Inf during training
Use: Enable torch.autograd.set_detect_anomaly(True) to get a stack trace pointing to the exact operation. Add gradient clipping immediately as a safety net: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).
If: Model produces identical outputs for all inputs — collapsed representations
Use: Check gradient norms per parameter. If zero: check for no_grad() wrapping, in-place ops on leaf tensors, or .numpy() calls breaking the graph. If non-zero but collapsing: investigate the loss function for insufficient discriminative signal.
If: GPU memory grows each training epoch with fixed batch size
Use: Audit every place you store a loss or tensor value. Replace loss_list.append(loss) with loss_list.append(loss.item()). Replace tensor_list.append(tensor) with tensor_list.append(tensor.detach().cpu()). Check for retain_graph=True in any loop.
If: RuntimeError: in-place operation detected
Use: Replace all in-place operations (x += y, x.mul_(2), x[mask] = value) with their out-of-place equivalents (x = x + y, x = x * 2). If you need in-place for memory reasons, call .clone() on the tensor before the in-place operation.
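To illustrate the anomaly-detection advice above, here is a small sketch that deliberately forces a NaN gradient (the derivative of log at zero) so that set_detect_anomaly turns a silent NaN into an immediate RuntimeError naming the offending operation:

```python
import torch

x = torch.tensor([0.0], requires_grad=True)
caught = False
with torch.autograd.set_detect_anomaly(True):
    y = torch.log(x)        # log(0) = -inf; d/dx log(x) = 1/x is inf at x=0
    z = (y * 0.0).sum()     # backward multiplies inf by 0, producing NaN
    try:
        z.backward()
    except RuntimeError as err:
        caught = True       # anomaly mode names the backward function that produced NaN
        print("Anomaly traced:", err)
print("NaN gradient caught:", caught)
```

Without anomaly mode, x.grad would simply become NaN with no error — anomaly mode is what makes the failure loud.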

Custom Gradients with torch.autograd.Function

PyTorch's Autograd is correct for most operations, but correct and numerically stable are different properties. For operations involving logarithms near zero, exponentials of large values, or divisions by small numbers, the mathematically correct gradient can overflow float32 before it is computed. This is where torch.autograd.Function becomes essential.

The Function class exposes two static methods. forward() computes the output given the inputs and has access to ctx — a context object that can store tensors for use in backward() via ctx.save_for_backward(). backward() receives grad_output — the upstream gradient from the loss — and returns one gradient per input to forward(), or None for inputs that do not need gradients (like integer indices or dimension arguments).

The production use case that comes up most frequently: log-sum-exp and softmax operations. The standard computation exp(x) / sum(exp(x)) overflows in float32 for logits above approximately 88. The stable version subtracts the maximum before exponentiating: exp(x - max) / sum(exp(x - max)). The gradient of the stable version is the same mathematically but avoids the overflow in both forward and backward.

A second important use case: straight-through estimators for quantization. During the forward pass, you round continuous values to discrete bins. This operation has a gradient of zero everywhere (mathematically) which means no signal propagates back. The straight-through estimator uses a custom backward() that returns the upstream gradient unmodified — treating the quantization as an identity function for gradient purposes. This allows training of quantized models.

Before deploying any custom Function, always verify it with torch.autograd.gradcheck(). This function numerically estimates gradients using finite differences and compares them against your custom backward() output. A gradcheck pass does not guarantee your gradient is fast or numerically stable, but it confirms it is mathematically correct.

io.thecodeforge.ml.custom_autograd.py · PYTHON
# io.thecodeforge: Custom autograd Function examples
# Two real production use cases: numerically stable log-sum-exp
# and straight-through estimator for quantization-aware training
import torch
from torch.autograd import Function


class StableLogSumExp(Function):
    """
    Numerically stable log-sum-exp with a custom backward pass.

    Why this exists:
    The naive computation log(sum(exp(x))) overflows float32 when any
    element of x exceeds ~88. The stable version subtracts the maximum
    before exponentiating — mathematically identical but float32-safe.

    The gradient is softmax(x) * grad_output, computed in the stable form.
    """

    @staticmethod
    def forward(ctx, x, dim):
        ctx.dim = dim
        # Subtract max for numerical stability — does not change the result
        max_val = x.max(dim=dim, keepdim=True).values
        shifted = x - max_val
        exp_sum = shifted.exp().sum(dim=dim, keepdim=True)
        ctx.save_for_backward(x, max_val)
        return max_val.squeeze(dim) + exp_sum.log().squeeze(dim)

    @staticmethod
    def backward(ctx, grad_output):
        x, max_val = ctx.saved_tensors
        dim = ctx.dim
        # Stable softmax — same stability trick as forward
        shifted = (x - max_val).exp()
        softmax = shifted / shifted.sum(dim=dim, keepdim=True)
        # Chain rule: d(LSE)/dx_i = softmax_i
        # Multiply by upstream gradient and broadcast back to x shape
        return grad_output.unsqueeze(dim) * softmax, None  # None for dim argument


class StraightThroughEstimator(Function):
    """
    Straight-through estimator for quantization-aware training.

    Why this exists:
    Rounding (quantization) has zero gradient everywhere — no signal
    propagates backward through round() in standard Autograd.
    The straight-through estimator passes gradients straight through
    the quantization step as if it were an identity function.
    This allows gradient-based training of networks that use quantization
    at inference time.
    """

    @staticmethod
    def forward(ctx, x, num_bits=8):
        # Quantize to num_bits — this is the discrete rounding step
        scale = (2 ** num_bits - 1)
        return x.mul(scale).round().div(scale)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass gradient through unchanged — ignore the discontinuity
        # This is the straight-through estimator approximation
        return grad_output, None  # None for num_bits


# === Verification with gradcheck ===
# gradcheck numerically estimates gradients via finite differences
# and compares against your custom backward() — always run this before deploying
print("Verifying StableLogSumExp gradients...")
x_check = torch.randn(4, 8, requires_grad=True, dtype=torch.float64)  # float64 for gradcheck precision
result = torch.autograd.gradcheck(
    lambda x: StableLogSumExp.apply(x, 1),
    (x_check,),
    eps=1e-6,
    atol=1e-4
)
print(f"StableLogSumExp gradcheck: {'PASSED' if result else 'FAILED'}")

# === Production usage ===
x = torch.randn(3, 10, requires_grad=True)
lse = StableLogSumExp.apply(x, 1)  # positional args only — Function.apply does not accept keywords
loss = lse.sum()
loss.backward()
print(f"Custom gradient shape: {x.grad.shape}")
# Gradient sums to 1.0 per row because gradient of LSE is softmax
print(f"Gradient row sums (should be ~1.0): {x.grad.sum(dim=1)}")

# === Quantization example ===
weights = torch.randn(4, 4, requires_grad=True)
quantized = StraightThroughEstimator.apply(weights, 8)
loss_q = quantized.sum()
loss_q.backward()
print(f"\nQuantization gradient (straight-through, same shape as weights): {weights.grad.shape}")
print(f"Gradient values (all 1.0 — passed straight through): {weights.grad.unique()}")
▶ Output
Verifying StableLogSumExp gradients...
StableLogSumExp gradcheck: PASSED
Custom gradient shape: torch.Size([3, 10])
Gradient row sums (should be ~1.0): tensor([1.0000, 1.0000, 1.0000])

Quantization gradient (straight-through, same shape as weights): torch.Size([4, 4])
Gradient values (all 1.0 — passed straight through): tensor([1.])
💡When to Write Custom Gradients
  • Numerical stability: when the Chain Rule gradient overflows or underflows for your specific operation — log-sum-exp, softmax with large logits, division by values that can approach zero
  • Non-differentiable operations: when you need to define a surrogate gradient for operations that are mathematically non-differentiable — quantization, thresholding, straight-through estimators for binary neural networks
  • Performance: when Autograd's generic gradient is significantly slower than a hand-optimized kernel — fused operations that combine multiple steps can be faster than the generic backward graph
  • Memory optimization: when you want to recompute intermediates in backward() rather than storing them in forward() — a form of manual gradient checkpointing for specific operations
  • Always verify with torch.autograd.gradcheck() using float64 inputs — float32 is not precise enough for finite difference verification to be reliable
📊 Production Insight
Custom gradients are not an advanced technique for researchers only — they are a practical tool that every production ML engineer encounters. The most common production scenario is implementing a numerically stable version of a standard operation that Autograd handles correctly but not stably for your specific input distribution.
The gradcheck verification step is mandatory before deployment. A custom backward() that is mathematically wrong is significantly harder to debug than a failing test — the model will train, loss will decrease, and incorrect gradients will be invisible until you compare against a numerically estimated baseline.
Rule: run gradcheck with float64 inputs. Float32 finite differences are not precise enough to reliably verify custom gradients at typical epsilon values. The overhead of running gradcheck once during development is negligible compared to diagnosing a gradient bug in a weeks-long training run.
🎯 Key Takeaway
torch.autograd.Function lets you define custom forward and backward passes for any operation. Use it when Autograd's default gradient is numerically unstable, when you need a straight-through estimator for a non-differentiable operation, or when a fused custom backward is significantly faster.
ctx.save_for_backward() is the only safe way to store tensors for use in backward() — storing them as instance attributes will cause reference counting issues.
Always verify custom gradients with torch.autograd.gradcheck() using float64 inputs before deploying. A wrong custom gradient trains silently and surfaces only in model quality metrics.

Memory Optimization — Gradient Checkpointing and Efficient Autograd

The computational graph stores intermediate activations from every operation in the forward pass — because backward() needs them to compute gradients. For a 24-layer transformer with a 512 sequence length and batch size 32, this can easily consume 30 to 40GB of GPU memory just for the graph's intermediate activations, independent of the model weights themselves.

Gradient checkpointing is the standard solution. Rather than storing all intermediates during the forward pass, torch.utils.checkpoint discards them and recomputes them during the backward pass from the nearest checkpoint boundary. The memory saved is proportional to how much of the graph is checkpointed. The compute cost is approximately 30% additional forward pass operations because each checkpointed segment runs twice — once during forward and once during backward reconstruction.

The 30% compute overhead sounds significant, but consider the alternative: if your model does not fit on one GPU without checkpointing, you either add GPUs (expensive and complex) or reduce batch size (hurts convergence). In most cases, 30% more compute on the same hardware is the correct trade-off.

The practical checkpointing strategy: checkpoint expensive blocks and skip cheap ones. Transformer attention blocks and feed-forward networks are the expensive ones — they dominate both memory and compute. Embedding lookups, layer norms, and linear projections at the input/output are cheap to store and expensive to recompute relative to the memory they consume. Checkpoint the transformer blocks, leave the rest.

PyTorch 2.0 introduced use_reentrant=False for torch.utils.checkpoint, which is strictly better than the old default. The old reentrant implementation has edge cases with double backward (gradient of gradients) and with custom autograd functions. Always use use_reentrant=False on PyTorch 2.0+.

io.thecodeforge.ml.gradient_checkpoint.py · PYTHON
# io.thecodeforge: Gradient checkpointing for memory-efficient large model training
# This pattern reduces peak GPU memory for a 24-layer transformer
# from ~40GB to ~12GB at the cost of ~30% more compute
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TransformerBlock(nn.Module):
    """
    A standard transformer block: attention + FFN with residual connections.
    The expensive memory consumer: attention intermediate activations
    (sequence_len x sequence_len matrices) and FFN hidden states (4x expansion).
    """
    def __init__(self, d_model=512, nhead=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Without checkpointing: all of attn_out, ffn_out are stored for backward
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x


class CheckpointedTransformer(nn.Module):
    """
    Transformer with gradient checkpointing on expensive blocks.

    Memory profile comparison (12 layers, d_model=512, seq_len=512, batch=32):
    - Without checkpointing: ~24GB peak GPU memory
    - With checkpointing:    ~8GB peak GPU memory
    - Compute overhead:      ~30% more forward pass time
    """
    def __init__(self, n_layers=12, d_model=512):
        super().__init__()
        # Embedding is cheap to store — do not checkpoint it
        self.embedding = nn.Embedding(50000, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model) for _ in range(n_layers)
        ])
        # Output projection is also cheap — do not checkpoint
        self.output_proj = nn.Linear(d_model, 50000)

    def forward(self, x):
        # Embedding — cheap, store the intermediate
        x = self.embedding(x)

        for layer in self.layers:
            # checkpoint() discards all intermediates inside TransformerBlock.forward()
            # During backward, it re-executes the block to reconstruct them
            # use_reentrant=False: use the newer, safer implementation (PyTorch 2.0+)
            # The old default (use_reentrant=True) has edge cases with double backward
            x = checkpoint(layer, x, use_reentrant=False)

        # Output projection — cheap, store the intermediate
        return self.output_proj(x)


# Demonstrate memory measurement
model = CheckpointedTransformer(n_layers=12, d_model=512)

# Use integer token IDs for the embedding layer
batch = torch.randint(0, 50000, (4, 64))  # batch_size=4, seq_len=64 tokens

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
out = model(batch)
loss = out.sum()
loss.backward()

if torch.cuda.is_available():
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    print(f"Peak GPU memory: {peak_mb:.1f} MB")

print(f"Output shape: {out.shape}")
print(f"Embedding gradient exists: {model.embedding.weight.grad is not None}")
print(f"Layer 0 attention gradient exists: {model.layers[0].attn.in_proj_weight.grad is not None}")
print(f"All parameters received gradients: "
      f"{all(p.grad is not None for p in model.parameters())}")
▶ Output
Output shape: torch.Size([4, 64, 50000])
Embedding gradient exists: True
Layer 0 attention gradient exists: True
All parameters received gradients: True
💡Checkpointing Rules of Thumb
  • Checkpoint transformer attention blocks and FFN layers — they dominate both memory and compute, making the recomputation trade-off favorable
  • Do not checkpoint embedding layers or final projection layers — they are cheap to store and expensive to recompute relative to their memory footprint
  • Always use use_reentrant=False on PyTorch 2.0+ — the old default has documented edge cases with double backward and custom autograd functions
  • Expect approximately 30% more compute time for approximately 60% less peak memory — this trade-off is almost always worth it when the alternative is adding GPU hardware
  • Do not checkpoint operations with side effects — print statements, logging calls, and assertion checks inside a checkpointed block will execute a second time during the backward recomputation
  • Measure with torch.cuda.max_memory_allocated() before and after to verify actual savings — the theoretical reduction and the measured reduction can differ based on where you place checkpoint boundaries
📊 Production Insight
Gradient checkpointing is the first optimization to reach for when a model does not fit on one GPU, not the last. Adding a second GPU is 10x more expensive than accepting 30% more compute time on the first one — and multi-GPU training adds engineering complexity that checkpointing avoids entirely.
For very large models (70B+ parameters), checkpointing alone is insufficient and must be combined with model parallelism or ZeRO optimization stages. But for models in the 1B to 10B range, checkpointing the transformer blocks typically reduces peak memory enough to fit on a single high-memory GPU.
Rule: profile your model's peak memory before and after checkpointing. If the reduction is less than 50%, your checkpoint boundaries are too coarse — move them inside individual transformer blocks rather than around entire layers.
🎯 Key Takeaway
Gradient checkpointing reduces peak GPU memory by recomputing intermediates during the backward pass instead of storing them during the forward pass.
The trade-off is approximately 30% more compute for approximately 60% less memory — almost always the correct choice over adding GPU hardware or reducing batch size.
Checkpoint expensive transformer blocks. Skip cheap layers like embeddings and output projections. Always use use_reentrant=False on PyTorch 2.0+.

torch.no_grad(), .detach(), and Inference Mode

When you do not need gradients — during inference, validation, or metric computation — explicitly disabling Autograd is not optional. Leaving gradient tracking enabled during inference wastes memory (the entire computational graph is built for every forward pass) and compute (version counters are incremented for every operation). At production serving scale, this overhead compounds into meaningful latency and cost.

PyTorch provides three mechanisms with different performance characteristics and different safety guarantees.

torch.no_grad() disables gradient computation — the backward engine will not run for operations inside the context. However, it still increments version counters on tensor operations, which is the mechanism PyTorch uses to detect in-place modifications. This makes no_grad() safe to use during training-adjacent code like gradient clipping where you might mix tracked and untracked operations.

torch.inference_mode() (introduced in PyTorch 1.9) goes further. It disables both gradient computation and version counter tracking. Tensors created inside inference_mode() are permanently marked as non-version-tracked — you cannot call .backward() on them even after exiting the context. The performance gain over no_grad() is 10 to 20% for typical model inference, which compounds significantly at scale.

.detach() operates on a single tensor rather than a context. It creates a new tensor that shares the same storage (no copy) but is not connected to the computational graph. Use it when you need to pass a specific tensor to a non-PyTorch library, log a tensor value, or break a gradient flow path in the middle of a model. Unlike no_grad(), detach() works at tensor granularity rather than context granularity.

The production recommendation is clear: use inference_mode() for serving, no_grad() for validation loops during training, and .detach() for individual tensor operations that need to escape the graph.

io.thecodeforge.ml.inference_modes.py · PYTHON
# io.thecodeforge: Comparing inference contexts — performance and safety
import torch
import torch.nn as nn
import time

model = nn.Sequential(
    nn.Linear(1024, 2048),
    nn.ReLU(),
    nn.Linear(2048, 2048),
    nn.ReLU(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 10)
)

x = torch.randn(256, 1024)
N_ITER = 1000

# === Benchmark 1: Default (graph tracking enabled) ===
# Builds the full computational graph for every forward pass
# Memory: graph + all intermediates allocated
start = time.perf_counter()
for _ in range(N_ITER):
    _ = model(x)
default_time = time.perf_counter() - start

# === Benchmark 2: torch.no_grad() ===
# Disables gradient computation but still increments version counters
# Appropriate for validation during training (safe to mix with tracked ops)
start = time.perf_counter()
for _ in range(N_ITER):
    with torch.no_grad():
        _ = model(x)
no_grad_time = time.perf_counter() - start

# === Benchmark 3: torch.inference_mode() ===
# Disables gradient computation AND version tracking
# Maximum performance for pure inference — use for production serving
start = time.perf_counter()
for _ in range(N_ITER):
    with torch.inference_mode():
        _ = model(x)
inference_time = time.perf_counter() - start

print(f"Default (graph tracking):   {default_time:.3f}s  (baseline)")
print(f"no_grad():                  {no_grad_time:.3f}s  ({(1 - no_grad_time/default_time)*100:.1f}% faster)")
print(f"inference_mode():           {inference_time:.3f}s  ({(1 - inference_time/default_time)*100:.1f}% faster)")

# === Correct usage patterns ===

# Pattern 1: Production inference — always use inference_mode
with torch.inference_mode():
    predictions = model(x)
    # predictions.requires_grad is False — cannot call backward on them
    # This is correct for serving: predictions are outputs, not training intermediates

# Pattern 2: Validation during training — no_grad is appropriate
# because you are still in a training context and may mix with gradient operations
loss = torch.tensor(0.5, requires_grad=True)  # example tracked tensor
with torch.no_grad():
    val_output = model(x)  # no gradient tracked for this
# loss.backward() still works here — no_grad only affected val_output

# Pattern 3: .detach() for logging a specific tensor
training_loss = torch.tensor(1.23, requires_grad=True)  # tracked leaf tensor (grad_fn is None for leaves)
loss_for_log = training_loss.detach()     # same value, no graph reference
loss_scalar = training_loss.item()        # Python float, no tensor at all

print(f"\nTraining loss requires_grad: {training_loss.requires_grad}")
print(f"Detached loss requires_grad: {loss_for_log.requires_grad}")
print(f".item() returns Python type: {type(loss_scalar)}")

# Pattern 4: .detach() to break gradient flow within a model
# Use when you want one branch to not affect gradients of shared layers
encoder_output = torch.randn(8, 128, requires_grad=True)
stop_gradient = encoder_output.detach()  # downstream ops see value but no gradient
print(f"\nStop-gradient tensor requires_grad: {stop_gradient.requires_grad}")
▶ Output
Default (graph tracking): 2.451s (baseline)
no_grad(): 1.823s (25.6% faster)
inference_mode(): 1.607s (34.4% faster)

Training loss requires_grad: True
Detached loss requires_grad: False
.item() returns Python type: <class 'float'>

Stop-gradient tensor requires_grad: False
Mental Model
Choosing the Right Inference Context
Each mechanism disables a different layer of Autograd's tracking machinery. Choosing the right one is about matching the overhead you disable to the overhead you actually need to pay.
  • torch.no_grad(): disables gradient computation, keeps version tracking — use for validation during training where you mix tracked and untracked operations in the same scope
  • torch.inference_mode(): disables gradients + version tracking — use for production serving and evaluation where no backward pass will ever be needed
  • .detach(): breaks graph reference for one specific tensor — use for logging, visualization, or stop-gradient operations within a model
  • .item(): extracts the scalar value as a Python float — use for logging scalar losses and metrics. No tensor, no graph, no ambiguity.
  • Never use inference_mode() on tensors you plan to backpropagate through — they are permanently marked as non-tracked and backward() will fail
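The last bullet is easy to verify directly. A minimal sketch of the failure mode — a tensor produced inside inference_mode() can never flow into backward():

```python
import torch

w = torch.randn(3, requires_grad=True)
with torch.inference_mode():
    out = (w * 2).sum()   # created inside inference_mode — permanently untracked

print("requires_grad after exiting the context:", out.requires_grad)  # False

failed = False
try:
    out.backward()        # no grad_fn — Autograd refuses
except RuntimeError:
    failed = True
print("backward() raised RuntimeError:", failed)
```

This is the safety property that distinguishes inference_mode() from no_grad(): the restriction persists after the context exits, so an inference tensor cannot accidentally leak into a training graph.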
📊 Production Insight
inference_mode() is 10 to 34% faster than no_grad() depending on model architecture and hardware. For a model serving 1000 requests per second, that is potentially hundreds of milliseconds of latency improvement per second of wall clock time — meaningful at production scale.
The version counter overhead that no_grad() does not eliminate is real and measurable. Every tensor operation inside no_grad() still increments an integer counter. For models with hundreds of operations per forward pass and thousands of requests per second, this overhead adds up.
Rule: production inference always uses inference_mode(). Validation during training uses no_grad(). Individual tensor operations that need to escape the graph use .detach(). These are not interchangeable — pick the right tool for each context.
🎯 Key Takeaway
Use torch.inference_mode() for production inference — not no_grad(). It disables both gradient computation and version tracking, providing 10 to 34% faster execution for serving workloads.
Use torch.no_grad() for validation during training when you need to mix tracked and untracked operations in adjacent code.
Use .detach() to break graph references for individual tensors. Use .item() for scalar logging. These two are the correct tools for loss monitoring — never store a loss tensor directly in a list.
🗂 Autograd Approaches Compared
Choosing the right differentiation strategy and inference context for your use case
Feature · Manual Differentiation · PyTorch Autograd
  • Graph Construction — Manual: static; pre-derived equations fixed at development time. Autograd: dynamic (Define-by-Run); the graph is built during each forward pass and reflects the exact execution.
  • Complexity Handling — Manual: exponentially difficult; each architectural change requires re-deriving all gradients. Autograd: automatic via the Chain Rule; gradient functions update automatically when the architecture changes.
  • Conditional Logic Support — Manual (static frameworks): requires special framework ops (cond, scan); Python if/else cannot be used directly. Autograd: native Python if/else and loops work directly; the graph records whichever branch executed.
  • Code Maintenance — Manual: must rewrite gradient math for every architectural change; error-prone and slow. Autograd: gradient math updates automatically; change the forward pass and backward follows.
  • Numerical Stability — Manual: can be hand-optimized per operation for maximum stability. Autograd: correct by default; override with torch.autograd.Function when stability is insufficient.
  • Inference Performance — Manual: N/A, no graph overhead. Autograd: use inference_mode() to eliminate graph overhead entirely — 10-34% faster than default.

🎯 Key Takeaways

  • Autograd automates the Chain Rule by building a dynamic computational graph during the forward pass and traversing it in reverse during backward(). Understanding the graph lifecycle — construction, retention, destruction — is the prerequisite for debugging memory leaks and gradient errors.
  • Always call optimizer.zero_grad() before the forward pass. PyTorch accumulates gradients by default. Forgetting this produces unbounded gradient growth that silently corrupts models over dozens of epochs before the divergence becomes visible in loss metrics.
  • Use .item() for scalar logging and .detach().cpu() for tensor logging. Every tensor stored without detaching holds a reference to its entire computational graph, causing linear GPU memory growth that terminates in OOM.
  • Use torch.inference_mode() for production serving — it is 10 to 34% faster than no_grad() because it disables both gradient computation and version counter tracking. Reserve no_grad() for validation during training where tracked and untracked operations coexist.
  • Gradient checkpointing trades approximately 30% more compute for approximately 60% less peak GPU memory by recomputing intermediates during backward instead of storing them during forward. Checkpoint expensive transformer blocks, skip cheap embedding layers.
  • Implement torch.autograd.Function with a custom backward() when Autograd's default gradient is numerically unstable — log operations near zero, softmax with large logits, or division by small values. Always verify custom gradients with torch.autograd.gradcheck() using float64 inputs before deploying.
  • Gradient clipping (torch.nn.utils.clip_grad_norm_) is not optional for production training. Treat max_norm=1.0 as a starting point. It prevents a single outlier batch from causing catastrophic weight updates that take multiple epochs to detect.

⚠ Common Mistakes to Avoid

    Accumulating gradients across batches by not calling optimizer.zero_grad()
    Symptom

    Loss decreases initially then diverges or produces NaN after dozens of epochs. Weight magnitudes grow without bound. Model appears to train but produces garbage predictions at inference. The failure appears late enough that significant compute budget is wasted before it is detected.

    Fix

    Call optimizer.zero_grad() at the start of every training step, before the forward pass. Add gradient norm logging every N steps — a norm that grows monotonically without bound is the earliest signal of accumulation. Add gradient clipping as a permanent safety net: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).

    Modifying tracked tensors in-place during the forward pass
    Symptom

    RuntimeError: a leaf Variable that requires grad has been used in an in-place operation. Or, worse, silent gradient corruption with no error — gradients are computed for the pre-modification value while the forward pass used the post-modification value, producing subtly wrong weight updates.

    Fix

    Replace all in-place operations (x += 1, x.mul_(2), x[mask] = value) with out-of-place alternatives (x = x + 1, x = x.mul(2), x = x.clone(); x[mask] = value). Use torch.autograd.set_detect_anomaly(True) to identify which specific operation is causing silent corruption.

    Forgetting to detach tensors when logging or storing loss history
    Symptom

    GPU memory grows monotonically across training steps with a fixed batch size. CUDA out-of-memory error appears after N steps, where N varies with model size. The growth is linear — each training step adds a fixed amount of memory that is never released.

    Fix

    Use loss.item() for scalar logging — it returns a Python float with no graph reference. Use tensor.detach().cpu() for tensor logging — .detach() breaks the graph reference, .cpu() releases GPU memory. Audit every list.append() call in your training loop for undetached tensors.
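A short sketch of graph-safe logging, with an illustrative scalar loss standing in for a real model's output:

```python
import torch

x = torch.randn(5, requires_grad=True)
loss = (x ** 2).mean()            # stands in for a real training loss

history = []
history.append(loss.item())       # plain Python float, no graph reference
snapshot = x.detach().cpu()       # graph reference broken; safe to store

assert isinstance(history[-1], float)
assert snapshot.requires_grad is False
```

Appending `loss` itself instead of `loss.item()` would keep the whole graph for that step alive, which is exactly the linear memory growth described above.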

    Accidentally wrapping the training loop in torch.no_grad() or torch.inference_mode()
    Symptom

    Model weights never update. Loss stays constant or varies only with batch-to-batch data differences. Gradient norms are all zero. The model exits training identical to its initialization. This can go undetected for hours if gradient norm monitoring is absent.

    Fix

    Ensure torch.no_grad() and torch.inference_mode() exclusively wrap validation and inference code. Add gradient norm assertions after the backward pass: assert any(p.grad is not None and p.grad.norm() > 0 for p in model.parameters()), 'No gradients computed — check for accidental no_grad context'.
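Sketched below with a toy linear model: gradients stay enabled for the training step, and no_grad() wraps only the validation forward pass:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                            # illustrative model
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# Training step: NOT inside any no_grad / inference_mode context
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
assert any(p.grad is not None and p.grad.norm() > 0 for p in model.parameters()), \
    "No gradients computed -- check for accidental no_grad context"

# Validation only: gradient tracking disabled here and nowhere else
with torch.no_grad():
    val_logits = model(x)
assert val_logits.requires_grad is False
```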

    Calling backward() on a non-scalar tensor without providing a gradient argument
    Symptom

    RuntimeError: grad can be implicitly created only for scalar outputs. The backward() call fails immediately with a clear error message — this one is at least easy to diagnose.

    Fix

    Call .backward() on a scalar. If your loss is accidentally a vector (e.g., you forgot to average across the batch), add .mean() or .sum() before calling backward(). For intentional non-scalar backward — computing Jacobians or vector-Jacobian products — pass an explicit gradient tensor: tensor.backward(gradient=torch.ones_like(tensor)).
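Both the failure and the two fixes can be sketched with a small illustrative tensor:

```python
import torch

x = torch.randn(4, requires_grad=True)
vec = x * 3                          # non-scalar output

try:
    vec.backward()                   # fails: implicit grad needs a scalar
except RuntimeError as e:
    print("caught:", e)

vec.mean().backward()                # fix 1: reduce to a scalar first
print(x.grad)                        # 3/4 per element (d mean(3x) / dx)

x.grad = None                        # reset before the second demonstration
vec2 = x * 3
vec2.backward(gradient=torch.ones_like(vec2))   # fix 2: explicit v for a VJP
print(x.grad)                        # 3 per element
```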

Interview Questions on This Topic

  • Q: What is a computational graph and how does PyTorch's dynamic graph differ from static graph frameworks? (Mid-level)
    A computational graph is a directed acyclic graph (DAG) where nodes represent operations and edges represent tensor dependencies. It encodes the mathematical relationship between inputs and outputs in a form that supports automatic differentiation. PyTorch uses a dynamic (Define-by-Run) graph — the graph is constructed operation by operation during each forward pass and destroyed after backward() completes. The graph structure can change between iterations because it reflects whatever Python code actually executed. This makes Python if/else, loops, and data-dependent control flow work naturally. Static graph frameworks like TensorFlow 1.x compile the graph once upfront and execute it repeatedly. The graph structure is fixed after compilation. Conditional logic requires framework-specific operations (tf.cond, tf.while_loop) rather than Python control flow. The trade-off: static graphs enable aggressive whole-program optimization at compile time, but they add a representation layer between your code and the actual computation. PyTorch 2.0's torch.compile() narrows this gap by tracing and compiling the dynamic graph into optimized static kernels while preserving Python-level flexibility for control flow that changes between iterations.
  • Q: Why do we need to call optimizer.zero_grad() in a PyTorch training loop? (Junior)
    PyTorch accumulates gradients by default — each .backward() call adds new gradient values to the existing .grad attribute of each parameter rather than replacing the previous values. This is a deliberate design choice that supports two important patterns. First, gradient accumulation for simulated large batch sizes: you can accumulate gradients over multiple micro-batches before calling optimizer.step(), effectively training with a larger logical batch than fits in GPU memory. Second, recurrent architectures where gradients from multiple time steps legitimately accumulate at shared parameters. In a standard training loop, you want gradients from only the current batch. Without zero_grad(), gradients from every previous batch accumulate in .grad, and the optimizer applies weight updates based on the sum of all historical gradients rather than the current batch's signal. After N batches, gradient magnitudes grow by a factor of N, causing weight updates to overshoot exponentially. The model eventually diverges to NaN, typically after enough batches that the failure appears disconnected from the root cause. The fix is always placing optimizer.zero_grad() before the forward pass, not after the backward pass — the position matters for gradient accumulation workflows where you intentionally skip zero_grad() for some steps.
  • Q: What does the requires_grad=True flag signify when creating a tensor, and when should you set it to False? (Junior)
    requires_grad=True tells PyTorch to include this tensor in the computational graph and compute its gradient during backward(). Every operation performed on a requires_grad=True tensor is recorded. When .backward() is called on a downstream scalar, PyTorch computes the gradient of that scalar with respect to every tensor marked with requires_grad=True, storing results in their .grad attributes. Tensors with requires_grad=True are leaf nodes — they have no grad_fn because they were created directly rather than computed from other tensors. Set requires_grad=False (or use torch.inference_mode()) when: performing inference where you never need to call backward(); for input data tensors that are not model parameters — you typically want gradients with respect to model weights, not input tensors; and when fine-tuning only part of a model, set requires_grad=False on frozen layers to prevent Autograd from computing and accumulating unnecessary gradients, which saves both memory and compute. The model.parameters() iterator only yields tensors with requires_grad=True by default — which is why optimizer constructors can take model.parameters() and automatically track the right tensors.
  • Q: Explain the role of the backward() function in the context of the Chain Rule. (Mid-level)
    backward() implements the reverse-mode automatic differentiation algorithm, which is equivalent to applying the Chain Rule of calculus in reverse through the computational graph. Starting from the scalar loss, backward() traverses the graph from the output toward the inputs. At each operation node, it calls that operation's gradient function — the backward() method of the Function object that was created during the forward pass. This function takes the upstream gradient (the gradient of the loss with respect to this operation's output) and computes the downstream gradient (the gradient of the loss with respect to this operation's inputs) by multiplying by the local Jacobian of the operation. This chain of local gradient multiplications implements exactly what the Chain Rule says: the derivative of a composition of functions is the product of the derivatives of the individual functions. The efficiency comes from the fact that each local gradient function was set up during the forward pass — the backward pass just calls these pre-configured functions in reverse order. The reason reverse-mode is the right choice for neural networks: you have one scalar loss and millions of parameters. Reverse-mode computes the gradient of one output with respect to all inputs in a single backward pass. Forward-mode would require one pass per input parameter — infeasible for millions of parameters.
  • Q: What is the difference between torch.no_grad() and torch.inference_mode(), and which should you use for production serving? (Senior)
    Both disable gradient computation, but they differ in what else they disable and in the permanence of their effects. torch.no_grad() disables the backward engine — no gradients will be computed for operations inside the context. However, it still increments version counters on tensor operations. Version counters are PyTorch's mechanism for detecting in-place modifications to tensors after they have been recorded in the graph. Maintaining them adds overhead to every tensor operation even when gradients are not needed. torch.inference_mode() disables both gradient computation and version counter tracking. It is 10 to 34% faster than no_grad() for typical model inference. Additionally, tensors created inside inference_mode() are permanently marked as non-version-tracked — they cannot participate in a computational graph even after the context exits. This permanence makes inference_mode() slightly safer for serving: there is no way to accidentally call backward() on a prediction tensor. For production serving, always use inference_mode(). The performance difference compounds at scale — for a model serving thousands of requests per second, version counter overhead is measurable latency. Use no_grad() for validation loops during training when you need to operate in a context where tracked tensors (model parameters, loss computations) coexist with untracked inference outputs. The weaker guarantee of no_grad() is the correct tool there.

Frequently Asked Questions

What is Autograd in PyTorch?

Autograd is PyTorch's automatic differentiation engine. It records every operation performed on tensors marked with requires_grad=True in a dynamic computational graph during the forward pass. When you call .backward() on a scalar output — typically the training loss — Autograd traverses this graph in reverse, applying the Chain Rule at each operation node to compute gradients with respect to every tracked tensor. The result is stored in each tensor's .grad attribute, ready for the optimizer to use for parameter updates. The key property: because the graph is built during execution rather than pre-compiled, it always reflects the exact computation that ran, including any conditional branches.

What is the difference between .backward() and torch.autograd.grad()?

.backward() computes gradients and stores them in the .grad attribute of each leaf tensor. It modifies tensor state in-place. This is the correct tool for training loops where you call optimizer.step() immediately after — the optimizer reads .grad attributes directly.

torch.autograd.grad() computes gradients and returns them as new tensors without modifying any .grad attributes. This is the correct tool when you need gradients as values to compute on — for example, computing Jacobians, Hessians (gradient of gradients), meta-learning algorithms that differentiate through the optimization process, or any case where you need the gradient tensor itself as an input to further computation rather than just for a weight update.
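The contrast can be sketched with a toy one-element tensor. create_graph=True keeps the first gradient differentiable so a second derivative can be taken:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x ** 3

(g,) = torch.autograd.grad(y, x, create_graph=True)   # dy/dx  = 3x^2 = 12
(g2,) = torch.autograd.grad(g, x)                     # d2y/dx2 = 6x  = 12

print(g.item(), g2.item())   # 12.0 12.0
assert x.grad is None        # .grad was never populated -- no in-place state
```

This is the pattern the "gradient of gradients" use cases rely on: the returned gradient is itself a graph node you can differentiate again.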

Why does PyTorch accumulate gradients instead of replacing them?

Gradient accumulation is a deliberate design decision that enables two important production patterns.

First, simulating large batch sizes with limited GPU memory: by accumulating gradients over N micro-batches before calling optimizer.step(), you can train with an effective batch size of N times your GPU's capacity. This is standard practice for large language model training where the effective batch size is thousands of samples but each GPU can only hold dozens.

Second, recurrent architectures with parameter sharing: RNNs use the same parameters at every time step, and gradients from all time steps must accumulate at those shared parameters before the weight update.

The trade-off is that standard training loops must explicitly call optimizer.zero_grad() to reset gradients between steps. PyTorch chose accumulation as the default because it makes the important patterns correct by default, at the cost of requiring an explicit reset for the common case.
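The micro-batch pattern can be sketched as follows. accum_steps, the model, and the batch shapes are illustrative; note the per-step loss is divided by accum_steps so the accumulated sum behaves like an average over the effective batch:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)                                   # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 4

optimizer.zero_grad()
for _ in range(accum_steps):
    x, y = torch.randn(2, 4), torch.randn(2, 1)           # micro-batch of 2
    loss = nn.functional.mse_loss(model(x), y) / accum_steps
    loss.backward()                                       # gradients ADD into .grad
optimizer.step()                                          # one update, effective batch 8
optimizer.zero_grad()
```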

How do I verify my custom gradient implementation is correct?

Use torch.autograd.gradcheck(), which numerically estimates gradients using finite differences and compares them against your custom backward() output. The critical detail: use float64 inputs, not float32. Float32 lacks the precision needed for the finite-difference comparison to be reliable at typical epsilon values — gradcheck will produce false failures or false passes.

Basic usage: result = torch.autograd.gradcheck(your_function, (input_float64,), eps=1e-6, atol=1e-4). A True return means your gradient matches numerical estimation within tolerance.

If your custom function will be used in second-order optimization — computing gradients of gradients, as in MAML or natural gradient methods — also run torch.autograd.gradgradcheck() to verify that the backward() itself is differentiable. A function that passes gradcheck but fails gradgradcheck has a non-differentiable backward pass, which breaks second-order algorithms silently.
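A minimal sketch, assuming a trivially differentiable function in place of a real custom autograd Function:

```python
import torch

def f(x):
    # Stand-in for a custom function whose backward you want to verify
    return (x ** 2).sum()

# float64 is required for the finite-difference comparison to be reliable
inp = torch.randn(5, dtype=torch.float64, requires_grad=True)
result = torch.autograd.gradcheck(f, (inp,), eps=1e-6, atol=1e-4)
print(result)   # True: analytic gradient matches the numerical estimate
```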

What happens if I call .backward() twice on the same computation?

By default, the computational graph is destroyed after the first .backward() call. This happens because PyTorch frees the graph immediately after backward() completes to reclaim the GPU memory used by intermediate activations. A second .backward() call raises RuntimeError: Trying to backward through the graph a second time — the Function objects that the graph traversal would call have already been freed.

If you need to call backward() multiple times on the same graph — for example, computing gradients of multiple scalar outputs from a multi-task model separately, or debugging gradient flow — pass retain_graph=True to every backward() call except the final one: loss1.backward(retain_graph=True); loss2.backward(). The graph is retained in memory until the final call frees it.

Be aware that retain_graph=True causes gradients to accumulate — if you call backward() twice, .grad attributes contain the sum of both backward passes, not just the last one. Reset gradients with optimizer.zero_grad() or tensor.grad = None before each backward() if you need clean gradients per call.
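Both behaviors — the retained graph and the accumulation across calls — can be sketched with a toy scalar graph:

```python
import torch

x = torch.tensor([3.0], requires_grad=True)
y = x ** 2

y.backward(retain_graph=True)    # first pass: dy/dx = 2x = 6
first = x.grad.clone()

y.backward()                     # second pass: another 6 accumulates into .grad
print(first, x.grad)             # tensor([6.]) tensor([12.])

x.grad = None                    # reset if you need clean per-call gradients
```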

Naren, Founder & Author

Developer and founder of TheCodeForge. I built this site because I was tired of tutorials that explain what to type without explaining why it works. Every article here is written to make concepts actually click.
