Senior 5 min · March 06, 2026

BatchNorm NaN at Inference — The Batch Size 1 Trap

Batch size 1 inference with model.

Naren · Founder
Plain-English first. Then code. Then the interview question.
Quick Answer
  • BatchNorm normalises activations per mini-batch: compute mean/variance, normalise, scale/shift via learned gamma/beta, update running stats
  • Training uses batch statistics; inference uses running averages – missing the switch breaks batch size 1 inference (NaN from an unguarded custom layer, a ValueError from stock nn.BatchNorm1d)
  • Affine parameters gamma/beta allow the network to undo normalisation if optimal, preventing over-constraint
  • Running stats must match inference distribution – a domain shift silently corrupts every BatchNorm layer
  • Below batch size ~16, batch statistics are noisy; use GroupNorm or LayerNorm instead
  • SyncBatchNorm synchronises statistics across GPUs in distributed training – without it, per-GPU models diverge
Plain-English First

Imagine a classroom where every student scores wildly differently — one gets 2/100, another gets 98/100. The teacher re-scales everyone's score to a fair 0–10 range before giving feedback, so no single outlier dominates the lesson. Batch Normalisation does exactly that for the numbers flowing through a neural network: it re-centres and rescales them at each layer so the network can learn consistently, regardless of how wild the raw values get. Without it, one layer's huge numbers bulldoze the next layer's tiny ones — and learning grinds to a halt.

Deep neural networks are notoriously hard to train. You can pick a great architecture and a sensible learning rate, and the network will still stall, diverge, or require weeks of babysitting. The culprit most of the time is internal covariate shift — the way the statistical distribution of each layer's inputs keeps changing as the weights in earlier layers update. This isn't a theoretical problem. It's the reason practitioners in 2014 were stuck training 20-layer networks with painstaking learning rate schedules and careful weight initialisation.

Batch Normalisation, introduced by Ioffe and Szegedy in 2015, attacks this directly. Instead of hoping the distributions stay well-behaved, it enforces it — normalising each mini-batch's activations to zero mean and unit variance before passing them to the next layer, then immediately handing the network two learnable parameters to rescale and reshift as it sees fit. The result was dramatic: networks could use 10× larger learning rates, were far less sensitive to initialisation, and in many cases didn't need Dropout at all.

By the end of this article you'll understand exactly what happens inside a BatchNorm layer during the forward and backward pass, why the behaviour fundamentally differs between training and inference (and how that difference causes silent production bugs), when you should reach for Layer Norm or Group Norm instead, and how to implement it from scratch in PyTorch so you can debug it with confidence when things go wrong.

What Actually Happens Inside a BatchNorm Layer — The Full Math

A BatchNorm layer does four things in sequence during the forward pass, and every production bug around it traces back to misunderstanding at least one of them.

First, it computes the mean and variance of the current mini-batch across the batch dimension for each feature. If your input is shape (N, C, H, W) for a conv layer, it averages over N, H, and W — treating each channel independently. For a fully-connected layer with input shape (N, D), it averages over N only.
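The reduction axes described above can be checked by computing the statistics by hand. This is a minimal sketch; the tensor shapes are arbitrary values chosen for illustration.

```python
import torch

# Conv-style input: (N, C, H, W). Statistics are per channel, reduced over N, H and W.
conv_input = torch.randn(8, 3, 4, 4)
conv_mean = conv_input.mean(dim=(0, 2, 3))
conv_var = conv_input.var(dim=(0, 2, 3), unbiased=False)

# Fully-connected input: (N, D). Statistics are per feature, reduced over N only.
fc_input = torch.randn(8, 5)
fc_mean = fc_input.mean(dim=0)
fc_var = fc_input.var(dim=0, unbiased=False)

print(conv_mean.shape, conv_var.shape)  # one value per channel
print(fc_mean.shape, fc_var.shape)      # one value per feature
```

In both cases the result has one entry per channel or feature, which is also the shape of gamma, beta and the running stats.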

Second, it normalises: subtract the batch mean, divide by the batch standard deviation (with a small epsilon for numerical stability). Now every feature has zero mean and unit variance within this batch.

Third — and this is the part people gloss over — it applies learnable affine parameters: gamma (scale) and beta (shift). This is crucial. Without gamma and beta, you've permanently clamped the network to a unit Gaussian at every layer, which is actually too restrictive. Gamma and beta let the network learn whatever distribution is optimal for the next layer. The normalisation step is just the starting point.

Fourth, during training it maintains running estimates of the mean and variance using an exponential moving average. These running stats are what get used at inference time — not the batch stats. That handoff from batch stats to running stats is where most production bugs live.

Epsilon is typically 1e-5. Momentum for the running average is typically 0.1 in PyTorch (meaning new_running = 0.9 × old_running + 0.1 × batch_stat). These are hyperparameters you can tune.
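Written out for a single feature over a mini-batch of m values, the four steps look roughly like this. The symbols are notation chosen here for illustration (α stands for PyTorch's momentum, and the "run" subscript marks the running estimates); the factor m/(m−1) reflects PyTorch storing the unbiased variance in running_var.

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i, &
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, &
y_i &= \gamma\,\hat{x}_i + \beta \\
\mu_{\text{run}} &\leftarrow (1-\alpha)\,\mu_{\text{run}} + \alpha\,\mu_B, &
\sigma^2_{\text{run}} &\leftarrow (1-\alpha)\,\sigma^2_{\text{run}} + \alpha\,\frac{m}{m-1}\,\sigma_B^2
\end{aligned}
```

At inference the same normalise-and-affine formula is applied with the running mean and running variance in place of the batch statistics.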

batch_norm_from_scratch.py · PYTHON
# Package: io.thecodeforge.ml.layers
import torch
import torch.nn as nn

class ManualBatchNorm1d(nn.Module):
    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable affine parameters — gamma starts at 1, beta at 0
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # Running stats — NOT parameters (buffers are not updated by optimiser)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        if activations.dim() != 2:
            raise ValueError(f'Expected 2D input (N, D), got {activations.dim()}D')

        if self.training:
            # ── TRAINING PATH: Compute stats over the current BATCH ──
            batch_mean = activations.mean(dim=0)          
            batch_var  = activations.var(dim=0, unbiased=False) 

            # Normalise using BATCH statistics
            activations_norm = (activations - batch_mean) / torch.sqrt(batch_var + self.eps)

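            # Update running stats (exponential moving average). PyTorch stores the
            # UNBIASED variance in running_var, so rescale the biased estimate by N / (N - 1).
            n = activations.shape[0]
            with torch.no_grad():
                unbiased_var = batch_var * n / max(n - 1, 1)  # guard: avoids 0/0 when N == 1
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)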
        else:
            # ── INFERENCE PATH: Use historical running stats ──
            activations_norm = (activations - self.running_mean) / torch.sqrt(self.running_var + self.eps)

        return self.gamma * activations_norm + self.beta
Output
ManualBatchNorm1d initialised and verified against nn.BatchNorm1d.
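One way to run such a check, sketched without depending on the class above: apply the formulas by hand to a random batch and compare against nn.BatchNorm1d, including the biased/unbiased split between the forward pass and the running_var update. The batch size, feature count and tolerances are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch = torch.randn(32, 4)
eps, momentum = 1e-5, 0.1

reference = nn.BatchNorm1d(4, eps=eps, momentum=momentum)
reference.train()
reference_out = reference(batch)

# Forward normalisation uses the biased variance (divide by N).
mean = batch.mean(dim=0)
biased_var = batch.var(dim=0, unbiased=False)
manual_out = (batch - mean) / torch.sqrt(biased_var + eps)  # gamma=1, beta=0 at init

# running_var is updated with the unbiased variance (divide by N - 1).
unbiased_var = batch.var(dim=0, unbiased=True)
expected_running_mean = (1 - momentum) * torch.zeros(4) + momentum * mean
expected_running_var = (1 - momentum) * torch.ones(4) + momentum * unbiased_var

out_ok = torch.allclose(manual_out, reference_out, atol=1e-5)
mean_ok = torch.allclose(expected_running_mean, reference.running_mean, atol=1e-6)
var_ok = torch.allclose(expected_running_var, reference.running_var, atol=1e-6)
print(out_ok, mean_ok, var_ok)
```

If the running_var comparison fails while the output comparison passes, the usual cause is using the biased variance in both places.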
Watch Out: Biased vs Unbiased Variance
PyTorch's BatchNorm uses biased variance (divides by N, not N-1) during the forward normalisation step, but stores the unbiased variance in running_var. If you're reimplementing this for a paper or a custom CUDA kernel, getting this wrong produces a subtle mismatch that's almost impossible to spot from loss curves alone. Always pass unbiased=False when computing batch_var in the normalise step.
Production Insight
The running stats momentum (default 0.1) controls how fast batch statistics are forgotten. If your training dataset is small, consider increasing momentum to 0.5 so running stats converge faster.
If you're using gradient accumulation, BatchNorm still computes its statistics over each micro-batch, not over the accumulated batch. Accumulating 8 micro-batches of 4 gives the optimiser an effective batch of 32, but every BatchNorm layer still normalises over 4 samples.
Rule: tune BatchNorm against the micro-batch size, not the effective batch size (micro-batch size × gradient accumulation steps).
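Gradient accumulation does not enlarge the batch that BatchNorm sees, and a small experiment makes that visible. The sketch below normalises the same 32 samples once as a full batch and once as 8 micro-batches of 4; the sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
bn.train()

full_batch = torch.randn(32, 4)
micro_batches = full_batch.chunk(8)  # 8 micro-batches of 4, as in 8 accumulation steps

full_out = bn(full_batch)                                      # statistics over 32 samples
micro_out = torch.cat([bn(chunk) for chunk in micro_batches])  # statistics over 4 samples each

same = torch.allclose(full_out, micro_out, atol=1e-4)
first_chunk_mean = micro_out[:4].mean(dim=0)
print("same normalisation:", same)
print("mean of first normalised micro-batch:", first_chunk_mean)
```

Each micro-batch is centred on its own mean, so the two outputs differ even though the underlying data is identical.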
Key Takeaway
BatchNorm normalises over batch and spatial dimensions per channel.
The four steps (mean→normalise→affine→running stats) are non-negotiable.
Always know which variance you're using — biased vs unbiased will mismatch your running stats.
Which Norm to Choose Based on Batch Size and Architecture?
If: Batch size >= 16 and CNN architecture
Use: BatchNorm. It's fast, regularises well, and is the most efficient option.
If: Batch size < 16 (e.g., medical imaging, small datasets) or variable-length sequences
Use: GroupNorm (prefer 32 groups) or LayerNorm. They don't depend on the batch dimension.
If: Transformer / RNN / NLP model
Use: LayerNorm is the standard choice. BatchNorm's batch dependence breaks sequence-level normalisation.
If: Style transfer or per-sample statistics needed
Use: InstanceNorm, which normalises per sample per channel.
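For the small-batch case, swapping normalisation layers can be done mechanically. The helper below is a hypothetical sketch (the name replace_bn_with_gn and the max_groups parameter are illustrative, not a PyTorch API); the new GroupNorm layers start from fresh parameters, so the model still needs training afterwards.

```python
import torch
import torch.nn as nn

def replace_bn_with_gn(module: nn.Module, max_groups: int = 32) -> nn.Module:
    """Recursively swap every BatchNorm2d for a GroupNorm over the same channels."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            channels = child.num_features
            # GroupNorm requires num_channels to be divisible by num_groups.
            groups = max(g for g in range(1, max_groups + 1) if channels % g == 0)
            setattr(module, name, nn.GroupNorm(groups, channels))
        else:
            replace_bn_with_gn(child, max_groups)
    return module

model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
model = replace_bn_with_gn(model)
model.train()
out = model(torch.randn(1, 3, 8, 8))  # batch size 1 is fine, even in train mode
print(model[1], out.shape)
```

Picking the largest divisor up to 32 is one reasonable default; the group count remains a hyperparameter worth tuning.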

Training vs Inference: The Behaviour Shift That Silently Breaks Models

This is the single most misunderstood aspect of BatchNorm in production, and it causes real, silent accuracy drops.

During training, the layer normalises using the statistics of the current mini-batch. This adds a subtle regularisation effect — because the mean and variance change slightly each forward pass (different data each batch), it's like adding structured noise to the activations. This is actually beneficial and is one reason BatchNorm often makes Dropout redundant.

During inference, you typically run on a single sample or a small batch. If you kept using batch statistics at inference time, a single unusual input would skew the normalisation for everything else in the batch. For batch size 1 the statistics degenerate entirely: the biased variance is exactly zero and the unbiased variance divides by N − 1 = 0. So BatchNorm switches to using the running mean and running variance accumulated during training.

Here's the critical implication: those running stats are only accurate if the training data distribution matches the inference data distribution. If you deploy a model and the input distribution shifts — say, images from a different camera, or text from a different domain — your running stats are wrong, and every BatchNorm layer silently corrupts its outputs.
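The effect of a distribution shift on an eval-mode layer is easy to reproduce. In this sketch the "training" data is standard normal and the shifted inputs (scale 3, offset 5) are a made-up stand-in for a new camera or domain.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# "Training": feed batches drawn from N(0, 1) so running stats settle near mean 0, var 1.
bn.train()
for _ in range(200):
    bn(torch.randn(64, 4))

bn.eval()
with torch.no_grad():
    in_dist = bn(torch.randn(1000, 4))              # same distribution as training
    shifted = bn(torch.randn(1000, 4) * 3.0 + 5.0)  # hypothetical shifted inputs

print(f"in-distribution: mean {in_dist.mean():.2f}, std {in_dist.std():.2f}")
print(f"shifted:         mean {shifted.mean():.2f}, std {shifted.std():.2f}")
```

The layer raises no error for the shifted inputs; it simply emits activations that are no longer centred or unit-scaled, which is why this failure is silent.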

batchnorm_train_eval_gotcha.py · PYTHON
# Package: io.thecodeforge.ml.diagnostics
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16))
single_input = torch.randn(1, 8)

# THE INCORRECT WAY: Using train mode for inference
model.train()
try:
    output = model(single_input)
    print("Train mode output (batch size 1):", output)
except Exception as e:
    print("Error:", e)

# THE CORRECT WAY: Using eval mode
model.eval()
with torch.no_grad():
    output = model(single_input)
    print("Eval mode output (deterministic):", output)
Output
Error: Expected more than 1 value per channel when training, got input size torch.Size([1, 16])
Eval mode output (deterministic): [valid tensor values...]
Pro Tip: Use a Context Manager to Guarantee eval Mode
Wrap every inference block with a small context manager that calls model.eval() on entry and restores the previous mode on exit. This is safer than manual calls because it's exception-safe: if inference raises an error, the model still returns to the mode it was in.
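A minimal sketch of such a context manager, built on contextlib. The name eval_mode is illustrative; note that model.train(flag) applies the flag to every submodule, so any layers that were deliberately frozen need re-freezing afterwards.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn

@contextmanager
def eval_mode(model: nn.Module):
    """Switch to eval mode with gradients off, then restore the previous mode."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield model
    finally:
        # Runs even if inference raised, so the mode is always restored.
        model.train(was_training)

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16))
model.train()
with eval_mode(model) as m:
    output = m(torch.randn(1, 8))  # batch size 1 is safe in eval mode
print(model.training, output.shape)
```

Recording the previous mode, rather than calling model.train() unconditionally, keeps the helper safe to use on a model that was already in eval mode.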
Production Insight
In a production serving pipeline, never rely on manual model.eval() calls — use a model wrapper that forces eval mode for all __call__ invocations.
If you deploy a model and notice accuracy degradation after a few months, check the inference input distribution. Running stats that were calibrated on outdated data will silently corrupt outputs.
Rule: always log the inference batch size and the mean/variance of each BatchNorm layer's running stats to detect distribution drift.
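The logging half of that rule can be sketched as a small helper that walks the model and summarises each BatchNorm layer's buffers. The function name and the summary keys are assumptions made for illustration; wire the result into whatever metrics system you already use.

```python
import torch
import torch.nn as nn

def bn_running_stats(model: nn.Module) -> dict:
    """Summarise running_mean / running_var for every BatchNorm layer."""
    summary = {}
    for name, module in model.named_modules():
        is_bn = isinstance(module, nn.modules.batchnorm._BatchNorm)
        if is_bn and module.running_mean is not None:
            has_nan = torch.isnan(module.running_mean).any() or torch.isnan(module.running_var).any()
            summary[name] = {
                "running_mean_avg": module.running_mean.mean().item(),
                "running_var_avg": module.running_var.mean().item(),
                "has_nan": bool(has_nan),
            }
    return summary

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU())
stats = bn_running_stats(model)
for layer_name, layer_stats in stats.items():
    print(layer_name, layer_stats)
```

The has_nan flag is worth alerting on directly, since a single NaN in running_var corrupts every later eval-mode output of that layer.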
Key Takeaway
eval() changes the entire computation path of BatchNorm – use it.
Running stats are only as good as the data that trained them.
Monitor running stats in production to catch silent accuracy regressions.
When to Recalibrate Running Stats?
If: New data domain (e.g., different camera, different accent)
Use: Either fine-tune with the new data to update running stats, or freeze BatchNorm layers and use them as-is.
If: Model served for months without retraining
Use: Monitor the mean and variance of each layer's inputs against its stored running stats. If the gap grows beyond a threshold, schedule a recalibration pass on fresh data.
If: Batch size at inference is always 1
Use: Running stats are your only normalisation. They must closely match the statistics of the data you serve, because any shift feeds straight into every layer's output.
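A recalibration pass can be sketched as follows: reset the running stats, then run forward passes in train mode with gradients disabled so that only the buffers change. The helper name recalibrate_bn and the synthetic "fresh" batches are assumptions; setting momentum to None is PyTorch's documented switch to a cumulative average. Train mode also activates layers such as Dropout, so this sketch suits models where that is acceptable.

```python
import torch
import torch.nn as nn

def recalibrate_bn(model: nn.Module, batches) -> None:
    """Re-estimate BatchNorm running stats from `batches` without touching any weights."""
    bn_layers = [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    saved_momentum = [m.momentum for m in bn_layers]
    for m in bn_layers:
        m.reset_running_stats()
        m.momentum = None  # None makes PyTorch use a cumulative (simple) average
    was_training = model.training
    model.train()  # train mode is what updates the running stats
    with torch.no_grad():
        for batch in batches:
            model(batch)
    for m, momentum in zip(bn_layers, saved_momentum):
        m.momentum = momentum
    model.train(was_training)

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16)).eval()
fresh_batches = [torch.randn(64, 8) * 2.0 + 1.0 for _ in range(20)]  # stand-in for new-domain data
recalibrate_bn(model, fresh_batches)
print(model[1].num_batches_tracked.item(), model[1].running_mean[:4])
```

Use batches large enough to give stable statistics; recalibrating with batch size 1 would reintroduce the degenerate-variance problem.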

Batch Norm vs Layer Norm vs Group Norm — When to Reach for Each

BatchNorm isn't the only normalisation game in town, and using the wrong one is a common architectural mistake that doesn't fail loudly — it just underperforms.

Transformers use Layer Norm, not Batch Norm, because sequence lengths vary and batch sizes at inference time are often 1. Computing batch statistics over a length-varying sequence with batch size 1 produces meaningless noise. Layer Norm sidesteps this completely because it only looks at the features of a single token at a time.

normalisation_comparison.py · PYTHON
# Package: io.thecodeforge.ml.benchmarks
import torch
import torch.nn as nn

x = torch.randn(2, 64, 28, 28) # (Batch, Channels, H, W)

bn = nn.BatchNorm2d(64)
ln = nn.LayerNorm([64, 28, 28])
gn = nn.GroupNorm(num_groups=32, num_channels=64)

print(f"BN Output Mean: {bn(x).mean():.4f}")
print(f"LN Output Mean: {ln(x).mean():.4f}")
print(f"GN Output Mean: {gn(x).mean():.4f}")
Output
BN Output Mean: 0.0000
LN Output Mean: -0.0000
GN Output Mean: 0.0000
Interview Gold: Why Do Transformers Use Layer Norm, Not Batch Norm?
Batch Norm's batch-dimension statistics are meaningless when sequences have variable length and batch size can be 1 at inference. Layer Norm normalises per-token over the feature dimension, so it's completely independent of batch size and sequence length.
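The independence claim can be demonstrated directly: normalise the same token alongside two different sets of neighbours and compare. The dimensions below are arbitrary, and BatchNorm1d is used in its default train mode to show the contrast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
layer_norm = nn.LayerNorm(d_model)
batch_norm = nn.BatchNorm1d(d_model)  # train mode by default

tokens = torch.randn(4, d_model)        # four token vectors
other_tokens = torch.randn(4, d_model)
other_tokens[0] = tokens[0]             # same first token, different neighbours

ln_same = torch.allclose(layer_norm(tokens)[0], layer_norm(other_tokens)[0], atol=1e-6)
bn_same = torch.allclose(batch_norm(tokens)[0], batch_norm(other_tokens)[0], atol=1e-6)
print("LayerNorm output unchanged by neighbours:", ln_same)
print("BatchNorm output unchanged by neighbours:", bn_same)
```

LayerNorm's output for a token depends only on that token's own features, whereas train-mode BatchNorm's output changes whenever the rest of the batch changes.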
Production Insight
When moving from a CNN model to a Transformer model, switching from BatchNorm to LayerNorm is not just an optimisation: batch statistics over padded, variable-length sequences are unreliable, and training with them tends to be unstable.
In image segmentation with small batch sizes (e.g., 2-4), GroupNorm typically outperforms BatchNorm by a wide margin. Many production segmentation models use GroupNorm with 32 groups.
Rule: for any model that must train on very small batches or variable-length inputs, prefer LayerNorm or GroupNorm over BatchNorm. Batch size 1 inference on its own is fine with BatchNorm, provided the model is in eval mode.
Key Takeaway
BatchNorm is for CNNs with big batches, LayerNorm for sequences, GroupNorm for small batches.
Wrong norm choice is an architectural bug that doesn't explode — it just silently hurts results.
Always match the normalization to the data and task characteristics.
Normalisation Layer Quick Decision
If: Model type: CNN or MLP with large batch size (>=16)
Use: BatchNorm. Computationally efficient and provides regularisation.
If: Model type: RNN, Transformer, NLP with variable length
Use: LayerNorm. Per-sample normalisation, no batch dependency.
If: Small batch size (<16) or detection/segmentation
Use: GroupNorm. Divide channels into groups and normalise over each group.
If: Style transfer or per-sample/channel normalisation
Use: InstanceNorm. Per sample per channel.

Production Gotchas — What the Papers Don't Tell You

Knowing the theory is table stakes. These are the issues that burn you in production.

Gotcha 1: Frozen BatchNorm during fine-tuning. When you fine-tune a pre-trained model on a small dataset, the running stats were calibrated on the pre-training data (ImageNet, for most vision backbones). If your new dataset has a different distribution, those running stats are wrong.

Gotcha 2: Data-parallel training shrinks the batch each BatchNorm layer sees. When using nn.DataParallel, the batch is split across GPUs and each GPU computes its own BatchNorm statistics independently, so every replica normalises over only its shard. nn.SyncBatchNorm synchronises statistics across GPUs, but it only works under DistributedDataParallel, so moving to DDP is part of the fix.

sync_batchnorm_and_frozen_bn.py · PYTHON
# Package: io.thecodeforge.ml.distributed
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64))

# Convert to SyncBatchNorm for Distributed Training
synced_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Function to freeze BN layers for fine-tuning
def freeze_bn(m):
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.eval()

model.apply(freeze_bn)
Output
Model converted to SyncBatchNorm and BN layers frozen for fine-tuning.
Watch Out: nn.DataParallel vs DistributedDataParallel and BatchNorm
nn.DataParallel is no longer the recommended path for multi-GPU training; PyTorch's documentation recommends DistributedDataParallel instead. DDP with SyncBatchNorm is the approach that gives every GPU the same normalisation statistics.
Production Insight
When fine-tuning a pretrained model, decide explicitly whether to freeze BN layers. If you freeze them, the model relies on the original dataset's statistics, and a mismatch can cause accuracy to drop by 10-20% on the new task.
If you allow BN layers to update during fine-tuning, they will adapt to the new data, but you need enough samples per GPU to get stable batch statistics (at least 16 per GPU).
Rule: for small fine-tuning datasets, freeze BN and treat it as a fixed preprocessor. For larger datasets, let them update but monitor the running stat distributions.
Key Takeaway
Fine-tuning requires a deliberate decision on BN: freeze or unfreeze? Pick based on dataset size and similarity.
Always use SyncBatchNorm with DDP — per-GPU statistics diverge silently.
Running stats are a crucial component of the model — treat them as parameters, not background noise.
BN Decision During Fine-tuning
If: New dataset is very small (<1000 samples), similar to original
Use: Freeze BN layers (model.eval() on them). Use original running stats.
If: New dataset is large enough to get stable batch stats per GPU
Use: Allow BN to update. Ensure SyncBatchNorm if using DDP.
If: New dataset is from a very different domain (e.g., natural images to medical scans)
Use: Freeze BN initially, fine-tune other layers, then unfreeze BN with a small learning rate for adaptation.
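One practical wrinkle with freezing: model.train() puts every submodule, BatchNorm included, back into train mode, so the freeze has to be re-applied afterwards. The sketch below shows that ordering; also freezing gamma and beta via requires_grad_(False) is an extra assumption here, beyond the eval() call alone.

```python
import torch
import torch.nn as nn

def freeze_bn(model: nn.Module) -> None:
    """Put every BatchNorm layer in eval mode and stop gradient updates to gamma/beta."""
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.eval()
            for param in module.parameters():
                param.requires_grad_(False)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
stats_before = model[1].running_mean.clone()

model.train()     # this flips BN back to train mode...
freeze_bn(model)  # ...so re-apply the freeze after every model.train() call
model(torch.randn(4, 3, 8, 8))

unchanged = torch.equal(stats_before, model[1].running_mean)
print("running stats unchanged after a training-mode forward pass:", unchanged)
```

In a typical loop this means calling freeze_bn(model) at the start of each epoch, right after model.train().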

BatchNorm in Distributed Training, Mixed Precision, and Deployment

BatchNorm's behaviour changes subtly under distributed training, mixed precision, and when exported to ONNX or TorchScript. These are the scenarios that cause late-stage surprises.

Distributed Training: With DistributedDataParallel, each process gets a subset of the batch. If you don't use SyncBatchNorm, each process normalises with statistics from its own shard only, so the batch BatchNorm actually sees is the per-GPU batch. By default DDP also broadcasts rank 0's buffers to the other ranks, so the saved running stats reflect only rank 0's data. The fix is nn.SyncBatchNorm.convert_sync_batchnorm(model) before wrapping in DDP.

Mixed Precision Training: When using torch.cuda.amp, BatchNorm layers should keep their parameters and statistics in float32. Under autocast the model's weights and buffers stay in float32, so this is handled for you. However, when you call model.half() or use custom CUDA kernels, BatchNorm in half precision can become numerically unstable, because epsilon (1e-5) is far below float16's resolution at typical variance magnitudes and gets rounded away.

ONNX Export: When exporting a model with BatchNorm to ONNX, the normalised output must use the running stats. Put the model in eval mode before calling the exporter (e.g., torch.onnx.export), and use an example input with a representative shape for the trace.

distributed_batchnorm_setup.py · PYTHON
# Package: io.thecodeforge.ml.distributed
# Assumes a torchrun launch (one process per GPU). `dataloader`, `criterion`,
# `optimizer` and `test_data` are assumed to be defined by the surrounding script.
import os
import torch
import torch.nn as nn
import torch.distributed as dist

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

# Model with BatchNorm
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64))

# ── CONVERT TO SYNCBATCHNORM BEFORE DDP ──
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# ── WRAP IN DISTRIBUTEDDATAPARALLEL ──
model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])

# ── TRAINING LOOP EXAMPLE ──
for data, target in dataloader:
    data = data.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)
    output = model(data)  # BN stats are synced across GPUs
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# ── INFERENCE: MUST BE EVAL MODE ──
model.eval()
with torch.no_grad():
    predictions = model(test_data.cuda())
Output
Model configured for synchronized BatchNorm across 8 GPUs.
Mental Model: BatchNorm as a Distributed Consensus Algorithm
  • Without SyncBatchNorm, each GPU runs its own 'consensus' on a subset of data, leading to divergent local models.
  • SyncBatchNorm adds a barrier and all-reduce of mean and variance — negligible cost for large batches, but for small per-GPU batches the overhead dominates.
  • The all-reduce communicates only two scalars per channel, so it's cheap — typically <5% overhead for ResNet-50 with 8 GPUs.
Production Insight
Mixed precision training (AMP) requires careful handling of BatchNorm's epsilon. In float16, the spacing between representable values near 1 is about 1e-3, far larger than ε = 1e-5, so adding ε to a variance of that size rounds straight back to the variance and the stabiliser is lost. Under AMP's autocast the BatchNorm parameters and statistics stay in float32, but if you manually half() your model, expect degraded accuracy.
For ONNX inference, always verify that the exported model uses running stats by comparing the output of the TorchScript model in eval mode to the ONNX output for the same input.
Rule: never use model.half() on a model with BatchNorm layers — use AMP's autocast instead.
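The float16 precision argument can be checked with a few lines. This sketch uses the default BatchNorm epsilon of 1e-5 and a variance of exactly 1 as the illustrative case.

```python
import torch

eps = 1e-5
fp16_spacing = torch.finfo(torch.float16).eps  # gap between 1.0 and the next float16
fp32_spacing = torch.finfo(torch.float32).eps  # same gap in float32

# Compute var + eps in float32, then round the result to float16.
var_plus_eps_fp32 = torch.tensor(1.0 + eps, dtype=torch.float32)
var_plus_eps_fp16 = var_plus_eps_fp32.to(torch.float16)

survives_fp32 = var_plus_eps_fp32.item() > 1.0
survives_fp16 = var_plus_eps_fp16.float().item() > 1.0
print(f"float16 spacing near 1: {fp16_spacing}")
print(f"float32 spacing near 1: {fp32_spacing}")
print("eps survives in float32:", survives_fp32)
print("eps survives in float16:", survives_fp16)
```

Epsilon only does its job when the arithmetic has enough resolution to represent it, which is the reason to keep BatchNorm's statistics in float32.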
Key Takeaway
SyncBatchNorm is required for multi-GPU training — per-GPU stats diverge silently.
Mixed precision and BatchNorm work together only if you let PyTorch handle the casting.
Exporting BatchNorm models requires eval mode — otherwise you export the train-time graph.
Deployment Decision Matrix
If: Multi-GPU training (DDP)
Use: SyncBatchNorm. Standard BN computes statistics per GPU.
If: Mixed precision training (AMP)
Use: Keep BN in float32 via AMP autocast; do not manually half() the model.
If: ONNX export
Use: Export the model in eval mode with a representative input shape.
If: TorchScript tracing
Use: Trace in eval mode to bake running stats into the graph.
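For the TorchScript row, a quick sanity check is to trace in eval mode and compare the traced module against the original on a fresh batch size 1 input. This is a sketch with a toy model; the short warm-up loop only exists to give the running stats non-trivial values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Give the running stats non-trivial values, as training would.
model.train()
for _ in range(10):
    model(torch.randn(16, 3, 8, 8))

model.eval()  # must happen BEFORE tracing
example = torch.randn(1, 3, 8, 8)
traced = torch.jit.trace(model, example)

with torch.no_grad():
    single = torch.randn(1, 3, 8, 8)
    matches = torch.allclose(model(single), traced(single), atol=1e-5)
print("traced graph matches eval-mode model at batch size 1:", matches)
```

The same compare-outputs pattern applies to an ONNX export, with the ONNX runtime's output taking the place of the traced module's.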
● Production incident · POST-MORTEM · severity: high

The Silent NaN Nightmare: Batch Size 1 Inference with model.train()

Symptom
Model outputs NaN for all inference requests with batch size 1, despite perfect training metrics. No error messages, no logs, just NaN.
Assumption
The team assumed that calling model.eval() was a performance optimisation, not a correctness requirement. They left the model in training mode during inference, reasoning that 'it shouldn't change the math much'.
Root cause
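In training mode, BatchNorm computes its statistics over the current batch. With batch size 1 on (N, D) input, every activation equals its own batch mean, so the biased variance is zero and the normalised value is 0 / sqrt(epsilon) = 0, which discards the input entirely. The unbiased variance used for the running_var update divides by N − 1 = 0 and evaluates to NaN; in an implementation without a guard, that NaN is written into running_var and propagates to every later output. Stock nn.BatchNorm1d refuses this case and raises a ValueError instead.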
Fix
Wrap all inference code with model.eval() and torch.no_grad(). Use a context manager to guarantee the mode is reverted even if inference raises an exception.
Key lesson
  • model.eval() is not optional for BatchNorm — it's a correctness requirement. Always switch to eval mode before inference.
  • Test inference with batch size 1 in every CI pipeline to catch this class of bug early.
  • Educate the team: training and eval are fundamentally different computational graphs for BatchNorm.
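The CI lesson can be turned into a small reusable check. The helper below is a hypothetical sketch (check_batch_size_one is not a library function); it asserts that the model is in eval mode and that a batch size 1 forward pass yields only finite values.

```python
import torch
import torch.nn as nn

def check_batch_size_one(model: nn.Module, feature_shape: tuple) -> torch.Tensor:
    """Fail loudly if the model is in train mode or emits non-finite values at batch size 1."""
    assert not model.training, "model must be in eval mode for inference"
    with torch.no_grad():
        output = model(torch.randn(1, *feature_shape))
    assert torch.isfinite(output).all(), "non-finite values at batch size 1"
    return output

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16)).eval()
output = check_batch_size_one(model, (8,))
print(output.shape)
```

Calling this from a unit test against the exact artefact that gets deployed catches a missing eval() before it reaches production.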
Production debug guide
Symptom → Action: diagnose BatchNorm-related issues in production
Symptom · 01
Model outputs NaN at inference but trains fine
Fix
Check if the model is in train() mode during inference. Force model.eval() and re-run.
Symptom · 02
Validation accuracy is good but test accuracy drops sharply
Fix
Compare mini-batch stats vs running stats. If running_stats diverge, the data distribution shifted or the stats weren't updated during training.
Symptom · 03
Training loss oscillates wildly; never converges
Fix
Check batch size — if below 16, the per-batch statistics are too noisy. Switch to GroupNorm or increase batch size.
Symptom · 04
Multi-GPU training yields inconsistent checkpoints per GPU rank
Fix
Check if you're using SyncBatchNorm. Standard BatchNorm computes per-GPU stats independently, so each rank normalises with different statistics.
Symptom · 05
Fine-tuning a pretrained model: new task performance is poor
Fix
The running stats from the pretrained dataset may be inappropriate for the new data. Either freeze BN layers (set to eval) or allow them to adapt with a small warm-up phase.
★ BatchNorm Quick Debug Cheat Sheet
Immediate commands and checks for the most common BatchNorm issues in PyTorch.
NaN at inference (batch size 1)
Immediate action
Switch model to eval mode
Commands
model.eval()
print('Training mode:', model.training)
Fix now
Wrap inference in a context manager:
class EvalMode:
    def __enter__(self): self.prev = model.training; model.eval()
    def __exit__(self, *args): model.train(self.prev)
Running stats not updating during training
Immediate action
Check if model is in training mode
Commands
print('Training:', model.training)
if not model.training: model.train()
Fix now
Call model.train() at start of training loop
Multi-GPU inconsistent checkpoints
Immediate action
Convert BatchNorm to SyncBatchNorm
Commands
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model)
Fix now
Use torch.nn.SyncBatchNorm for all _BatchNorm layers in DDP
BN in LayerNorm context (Transformer model training)
Immediate action
Replace BatchNorm with LayerNorm
Commands
self.norm = nn.LayerNorm(d_model)  # inside TransformerBlock.__init__
x = self.norm(x)
Fix now
Swap BatchNorm1d to LayerNorm in transformer architectures
Comparison: BatchNorm, LayerNorm, GroupNorm, InstanceNorm
| Aspect | Batch Norm | Layer Norm | Group Norm | Instance Norm |
|---|---|---|---|---|
| Normalises over | Batch + spatial dims, per channel | All feature dims, per sample | Channel groups + spatial, per sample | Spatial dims only, per sample per channel |
| Batch size dependency | High: needs batch size ≥ 16 for stable stats | None: fully per-sample | None: fully per-sample | None: fully per-sample |
| train vs eval behaviour | Different (batch stats vs running stats) | Identical | Identical | Identical |
| Primary use case | CNNs with large batches (ResNet, EfficientNet) | Transformers, LLMs, NLP | Detection/segmentation with small batches | Style transfer |
| Running stats required | Yes: must be calibrated at training time | No | No | No |
| Multi-GPU complication | Needs SyncBatchNorm for correct stats | None | None | None |
| Regularisation effect | Yes: batch-level noise acts as regulariser | Mild | Mild | None |
| Learnable parameters | gamma + beta per channel | gamma + beta per feature | gamma + beta per channel | gamma + beta per channel |

Key takeaways

1. BatchNorm's four steps are: compute batch mean/variance → normalise → apply learnable gamma/beta → update running stats. Understanding all four steps lets you debug it; most bugs live in the running stats update.
2. The train()/eval() switch is not cosmetic: it changes whether the layer uses live batch statistics or accumulated running stats. NaN outputs with batch size 1 at inference are almost always caused by a missing model.eval() call.
3. BatchNorm breaks below batch size ~16 and in variable-length sequence models. Use GroupNorm for small-batch CNNs and LayerNorm for Transformers; these choices are architectural, not stylistic.
4. In multi-GPU training, always use SyncBatchNorm with DistributedDataParallel. nn.DataParallel with standard BatchNorm silently computes per-GPU statistics that diverge, leading to subtle model quality degradation that's very hard to diagnose from loss curves alone.
5. Running stats are as important as model weights. Monitor their distributions in production to detect data drift before it hurts accuracy.

Common mistakes to avoid

Forgetting model.eval() before inference

Symptom
NaN outputs with batch size 1 or inconsistent results across runs.
Fix
Enforce eval mode globally before deployment: either wrap inference in a context manager or use a model wrapper that forces eval.

Using BatchNorm with tiny batches (< 8)

Symptom
Erratic training loss, poor validation performance.
Fix
Increase batch size to at least 16, or switch to GroupNorm (32 groups) if batch size cannot be increased.

Overlooking SyncBatchNorm in DistributedDataParallel

Symptom
Model quality degrades as number of GPUs increases, checkpoints per GPU diverge.
Fix
Convert all BatchNorm layers to SyncBatchNorm before wrapping the model in DDP: model = nn.SyncBatchNorm.convert_sync_batchnorm(model).

Incorrect layer ordering (e.g., ReLU before BN)

Symptom
Slower convergence and unstable training.
Fix
Place BN between the linear/conv layer and the activation: Linear -> BN -> ReLU. This way the activation receives normalised inputs.

Not updating running stats during fine-tuning

Symptom
Poor performance on new task despite good training loss.
Fix
If you freeze BN layers, fine-tune other layers only. If you unfreeze BN, ensure the model is in train() mode during fine-tuning so running stats update.
INTERVIEW PREP · PRACTICE MODE

Interview Questions on This Topic

Q01 · JUNIOR
Explain how Batch Normalisation helps mitigate the 'Dying ReLU' problem.
Q02 · SENIOR
How does Batch Normalisation introduce an implicit regularisation effect...
Q03 · SENIOR
Debug a deep CNN that converges slowly with BatchNorm but fails at infer...
Q04 · SENIOR
Under what specific mathematical conditions would the 'gamma' and 'beta'...
Q05 · SENIOR
Why is SyncBatchNorm necessary in multi-GPU environments, and what are t...
Q01 of 05 · JUNIOR

Explain how Batch Normalisation helps mitigate the 'Dying ReLU' problem.

ANSWER
The 'Dying ReLU' problem occurs when negative inputs cause the ReLU activation to output zero permanently, preventing weight updates. BatchNorm helps by centering the activations around zero (via the normalisation step) and then scaling/shifting with learnable gamma and beta. By controlling the distribution of activations, BatchNorm reduces the chance that a large number of inputs fall into the negative region. Additionally, the affine parameters can learn a positive bias to push activations away from zero if needed.
FAQ · 5 QUESTIONS

Frequently Asked Questions

01. Why does Batch Normalisation speed up training?
02. Does Batch Normalisation replace Dropout?
03. When should I use LayerNorm over BatchNorm?
04. Can I train a BatchNorm model with gradient accumulation?
05. How do I handle BatchNorm during transfer learning when the target dataset is small?