BatchNorm NaN at Inference — The Batch Size 1 Trap
Batch size 1 inference with model.train().
- BatchNorm normalises activations per mini-batch: compute mean/variance, normalise, scale/shift via learned gamma/beta, update running stats
- Training uses batch statistics; inference uses running averages – missing the switch produces NaN with batch size 1
- Affine parameters gamma/beta allow the network to undo normalisation if optimal, preventing over-constraint
- Running stats must match inference distribution – a domain shift silently corrupts every BatchNorm layer
- Below batch size ~16, batch statistics are noisy; use GroupNorm or LayerNorm instead
- SyncBatchNorm synchronises statistics across GPUs in distributed training – without it, per-GPU models diverge
Imagine a classroom where every student scores wildly differently — one gets 2/100, another gets 98/100. The teacher re-scales everyone's score to a fair 0–10 range before giving feedback, so no single outlier dominates the lesson. Batch Normalisation does exactly that for the numbers flowing through a neural network: it re-centres and rescales them at each layer so the network can learn consistently, regardless of how wild the raw values get. Without it, one layer's huge numbers bulldoze the next layer's tiny ones — and learning grinds to a halt.
Deep neural networks are notoriously hard to train. You can pick a great architecture and a sensible learning rate, and the network will still stall, diverge, or require weeks of babysitting. The culprit most of the time is internal covariate shift — the way the statistical distribution of each layer's inputs keeps changing as the weights in earlier layers update. This isn't a theoretical problem. It's the reason practitioners in 2014 were stuck training 20-layer networks with painstaking learning rate schedules and careful weight initialisation.
Batch Normalisation, introduced by Ioffe and Szegedy in 2015, attacks this directly. Instead of hoping the distributions stay well-behaved, it enforces it — normalising each mini-batch's activations to zero mean and unit variance before passing them to the next layer, then immediately handing the network two learnable parameters to rescale and reshift as it sees fit. The result was dramatic: networks could use 10× larger learning rates, were far less sensitive to initialisation, and in many cases didn't need Dropout at all.
By the end of this article you'll understand exactly what happens inside a BatchNorm layer during the forward and backward pass, why the behaviour fundamentally differs between training and inference (and how that difference causes silent production bugs), when you should reach for Layer Norm or Group Norm instead, and how to implement it from scratch in PyTorch so you can debug it with confidence when things go wrong.
What Actually Happens Inside a BatchNorm Layer — The Full Math
A BatchNorm layer does four things in sequence during the forward pass, and every production bug around it traces back to misunderstanding at least one of them.
First, it computes the mean and variance of the current mini-batch across the batch dimension for each feature. If your input is shape (N, C, H, W) for a conv layer, it averages over N, H, and W — treating each channel independently. For a fully-connected layer with input shape (N, D), it averages over N only.
Second, it normalises: subtract the batch mean, divide by the batch standard deviation (with a small epsilon for numerical stability). Now every feature has zero mean and unit variance within this batch.
Third — and this is the part people gloss over — it applies learnable affine parameters: gamma (scale) and beta (shift). This is crucial. Without gamma and beta, you've permanently clamped the network to a unit Gaussian at every layer, which is actually too restrictive. Gamma and beta let the network learn whatever distribution is optimal for the next layer. The normalisation step is just the starting point.
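Here is a contrived PyTorch check that gamma and beta really can undo the normalisation. Setting gamma to the batch standard deviation and beta to the batch mean turns a training-mode nn.BatchNorm1d into the identity for that batch. In real training these values are learned rather than assigned, and this exact inversion only holds for the one batch whose statistics were used.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 3) * 4 + 7            # a batch of 16 samples, 3 features

bn = nn.BatchNorm1d(3).train()
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)        # the biased variance BatchNorm normalises with

with torch.no_grad():
    bn.weight.copy_(torch.sqrt(var + bn.eps))   # gamma undoes the division
    bn.bias.copy_(mean)                          # beta undoes the mean subtraction

y = bn(x)
print((y - x).abs().max())   # close to 0, up to float32 rounding
```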
Fourth, during training it maintains running estimates of the mean and variance using an exponential moving average. These running stats are what get used at inference time — not the batch stats. That handoff from batch stats to running stats is where most production bugs live.
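The four steps can be written compactly for a single feature or channel. The notation is assumed here: there are m values per channel (m = N for an (N, D) input, m = N·H·W for an (N, C, H, W) input), α is the momentum, and μ̂ and σ̂² are the running estimates. The m/(m−1) factor reflects PyTorch storing the unbiased variance in its running estimate; other frameworks may differ. Note that at m = 1 the batch variance is 0 and that factor divides by zero, which is one route to a NaN.

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i,
\qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}},
\qquad
y_i = \gamma\,\hat{x}_i + \beta \\
\hat{\mu} &\leftarrow (1-\alpha)\,\hat{\mu} + \alpha\,\mu_B,
\qquad
\hat{\sigma}^2 \leftarrow (1-\alpha)\,\hat{\sigma}^2 + \alpha\,\frac{m}{m-1}\,\sigma_B^2 \\
y_{\text{inference}} &= \gamma\,\frac{x - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \epsilon}} + \beta
\end{aligned}
```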
Epsilon is typically 1e-5. Momentum for the running average is typically 0.1 in PyTorch (meaning new_running = 0.9 × old_running + 0.1 × batch_stat; for the variance, PyTorch feeds the unbiased batch variance into this update). These are hyperparameters you can tune.
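The following is a from-scratch sketch of those four steps for (N, C, H, W) inputs, checked against nn.BatchNorm2d in both modes. The class name ScratchBatchNorm2d is hypothetical. It also omits options the real layer supports, such as affine=False, track_running_stats=False, and momentum=None (cumulative average).

```python
import torch
import torch.nn as nn


class ScratchBatchNorm2d(nn.Module):
    """Minimal BatchNorm2d sketch for inputs of shape (N, C, H, W)."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps, self.momentum = eps, momentum
        self.gamma = nn.Parameter(torch.ones(num_features))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(num_features))   # learnable shift
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Step 1: per-channel statistics over N, H and W.
            mean = x.mean(dim=(0, 2, 3))
            var = x.var(dim=(0, 2, 3), unbiased=False)   # biased: used to normalise
            n = x.numel() / x.size(1)                     # values per channel
            with torch.no_grad():
                # Step 4: exponential moving average; the stored variance is unbiased.
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(
                    self.momentum * var * n / (n - 1))
        else:
            # Inference: use the running estimates, never the batch.
            mean, var = self.running_mean, self.running_var
        # Step 2: normalise. Step 3: scale and shift.
        x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        return self.gamma[None, :, None, None] * x_hat + self.beta[None, :, None, None]


torch.manual_seed(0)
ref, mine = nn.BatchNorm2d(3), ScratchBatchNorm2d(3)
x = torch.randn(8, 3, 5, 5) * 2 + 1
print((ref(x) - mine(x)).abs().max())   # training mode: tiny difference
```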
Training vs Inference: The Behaviour Shift That Silently Breaks Models
This is the single most misunderstood aspect of BatchNorm in production, and it causes real, silent accuracy drops.
During training, the layer normalises using the statistics of the current mini-batch. This adds a subtle regularisation effect — because the mean and variance change slightly each forward pass (different data each batch), it's like adding structured noise to the activations. This is actually beneficial and is one reason BatchNorm often makes Dropout redundant.
During inference, you typically run on a single sample or a small batch. If you kept using batch statistics at inference time, a single unusual input would skew the normalisation for everything else in the batch. For a fully connected layer at batch size 1, the variance is degenerate: the biased estimate is exactly zero and the unbiased estimate divides by zero. So BatchNorm switches to using the running mean and running variance accumulated during training.
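The batch size 1 case can be seen directly in PyTorch. For an (N, D) input, nn.BatchNorm1d in training mode raises a ValueError rather than silently returning NaN. A hand-rolled layer gets no such guard: the biased variance is zero, so every value normalises to exactly 0, and the unbiased variance used for the running-stat update is NaN. Eval mode sidesteps all of this. (A conv input of shape (1, C, H, W) does not raise, because H × W values per channel still exist.)

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)            # fully connected case: input shape (N, D)
x = torch.randn(1, 4)             # batch size 1

# 1. Train mode: PyTorch rejects a single value per channel.
bn.train()
try:
    bn(x)
    train_mode_error = None
except ValueError as err:
    train_mode_error = str(err)
print("train mode:", train_mode_error)

# 2. What a hand-rolled layer would compute from one sample instead.
biased_var = x.var(dim=0, unbiased=False)   # all zeros: one value has no spread
unbiased_var = x.var(dim=0, unbiased=True)  # nan: divides by n - 1 = 0
x_hat = (x - x.mean(dim=0)) / torch.sqrt(biased_var + bn.eps)
print(biased_var, unbiased_var, x_hat)      # x_hat is all zeros: the input is erased

# 3. Eval mode: running statistics are used, so batch size 1 is well defined.
bn.eval()
print("eval mode:", bn(x))
```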
Here's the critical implication: those running stats are only accurate if the training data distribution matches the inference data distribution. If you deploy a model and the input distribution shifts — say, images from a different camera, or text from a different domain — your running stats are wrong, and every BatchNorm layer silently corrupts its outputs.
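One way to catch that silent corruption is to compare what each BatchNorm layer receives against what it stored. The sketch below uses forward pre-hooks to measure, per layer, how far the incoming batch mean sits from the running mean in units of running standard deviation. The helper name batchnorm_shift_report and the choice of metric are assumptions for illustration, not a standard API, and it only looks at means; what gap counts as "too large" is a judgment call.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


@torch.no_grad()
def batchnorm_shift_report(model: nn.Module, batch: torch.Tensor) -> dict:
    """Hypothetical helper: per BatchNorm layer, the largest per-channel gap between
    the incoming batch mean and the running mean, in running standard deviations."""
    report, handles = {}, []

    def make_hook(name):
        def hook(module, inputs):
            x = inputs[0]
            dims = tuple(d for d in range(x.dim()) if d != 1)   # all dims except channels
            batch_mean = x.mean(dim=dims)
            running_std = torch.sqrt(module.running_var + module.eps)
            gap = (batch_mean - module.running_mean).abs() / running_std
            report[name] = gap.max().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, BN_TYPES):
            handles.append(module.register_forward_pre_hook(make_hook(name)))

    was_training = model.training
    model.eval()                      # use running stats, as in deployment
    try:
        model(batch)
    finally:
        for handle in handles:
            handle.remove()
        model.train(was_training)     # leave the model in the mode we found it
    return report


# Usage with a fresh layer (running_mean = 0, running_var = 1):
net = nn.Sequential(nn.BatchNorm1d(4))
print(batchnorm_shift_report(net, torch.zeros(8, 4)))          # {'0': 0.0}
print(batchnorm_shift_report(net, torch.full((8, 4), 5.0)))    # about 5 std devs off
```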
Use a context manager that calls model.eval() on entry and model.train() on exit. This is safer than manual calls because it is exception-safe: if inference raises an error, the model still returns to training mode correctly. In serving code, rather than relying on individual model.eval() calls, use a model wrapper that forces eval mode for all __call__ invocations.
Batch Norm vs Layer Norm vs Group Norm: When to Reach for Each
BatchNorm isn't the only normalisation game in town, and using the wrong one is a common architectural mistake that doesn't fail loudly — it just underperforms.
Transformers use Layer Norm, not Batch Norm, because sequence lengths vary and batch sizes at inference time are often 1. Computing batch statistics over a length-varying sequence with batch size 1 produces meaningless noise. Layer Norm sidesteps this completely because it only looks at the features of a single token at a time.
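The practical difference between the three layers is which axes the mean and variance are pooled over. The sketch below reproduces each PyTorch layer (with the affine part switched off) using a single hypothetical normalise helper and different reduction dimensions. It then applies a Transformer-style LayerNorm per token on a batch of 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(4, 6, 5, 5)          # (N, C, H, W)


def normalise(t, dims):
    """Zero mean, unit variance over `dims` (biased variance, as all three layers use)."""
    mean = t.mean(dim=dims, keepdim=True)
    var = t.var(dim=dims, unbiased=False, keepdim=True)
    return (t - mean) / torch.sqrt(var + eps)


# BatchNorm2d: one mean/var per channel, pooled over N, H, W, so it depends on other samples.
bn_out = nn.BatchNorm2d(6, affine=False).train()(x)
bn_ref = normalise(x, (0, 2, 3))

# LayerNorm over (C, H, W): one mean/var per sample, so batch size is irrelevant.
ln_out = nn.LayerNorm([6, 5, 5], elementwise_affine=False)(x)
ln_ref = normalise(x, (1, 2, 3))

# GroupNorm with 2 groups: one mean/var per sample per group of 3 channels.
gn_out = nn.GroupNorm(2, 6, affine=False)(x)
gn_ref = normalise(x.view(4, 2, 3, 5, 5), (2, 3, 4)).view_as(x)

for name, out, ref in [("batch", bn_out, bn_ref), ("layer", ln_out, ln_ref), ("group", gn_out, gn_ref)]:
    print(name, (out - ref).abs().max().item())

# Transformer-style LayerNorm: per token over the feature dim, fine at batch size 1.
tokens = torch.randn(1, 7, 16)       # (batch=1, seq_len=7, d_model=16)
token_out = nn.LayerNorm(16)(tokens)
print(token_out.mean(dim=-1))        # about 0 for every token
```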
Production Gotchas — What the Papers Don't Tell You
Knowing the theory is table stakes. These are the issues that burn you in production.
Gotcha 1: Frozen BatchNorm during fine-tuning. When you fine-tune a pre-trained model on a small dataset, its running stats were calibrated on the pre-training data (typically ImageNet). If your new dataset has a different distribution, those running stats are wrong.
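When the new dataset is too small to re-estimate statistics reliably, one common option is to freeze the BatchNorm layers so they keep the pre-trained running stats and affine parameters. This trades some distribution mismatch for stability; with enough data, leaving BatchNorm in training mode lets the stats adapt instead. The helper name freeze_batchnorm is hypothetical. The key detail is that model.train() flips every submodule back to training mode, so the freeze must be re-applied after each call to train(). It also helps to pass only parameters with requires_grad=True to the optimiser.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def freeze_batchnorm(model: nn.Module) -> nn.Module:
    """Hypothetical helper: keep BatchNorm running stats and gamma/beta fixed."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()                      # normalise with stored stats, no EMA updates
            for p in module.parameters():
                p.requires_grad_(False)        # freeze gamma and beta
    return model


# Usage: re-apply after every model.train(), which would otherwise undo the freeze.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net.train()
freeze_batchnorm(net)
stats_before = net[1].running_mean.clone()
net(torch.randn(2, 3, 16, 16))
print(torch.equal(net[1].running_mean, stats_before))   # True: running stats untouched
```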
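Gotcha 2: Data-parallel training shrinks the batch BatchNorm actually sees. When using nn.DataParallel, each GPU computes its own BatchNorm statistics over only its slice of the batch, so normalisation is noisier than the global batch size suggests, and only the replica on the first device writes its running-stat updates back to the model. nn.SyncBatchNorm synchronises statistics across GPUs, but PyTorch only supports it with DistributedDataParallel, so using it means moving to DDP.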
On a small fine-tuning dataset, a common fix is to freeze the BatchNorm layers (call .eval() on them) and use the original running stats.
BatchNorm in Distributed Training, Mixed Precision, and Deployment
BatchNorm's behaviour changes subtly under distributed training, mixed precision, and when exported to ONNX or TorchScript. These are the scenarios that cause late-stage surprises.
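Distributed Training: With DistributedDataParallel, each process gets a subset of the batch. Without SyncBatchNorm, each process normalises using only its own mini-batch, so BatchNorm effectively sees the smaller per-process batch, and each process accumulates its own running statistics. By default DDP broadcasts rank 0's buffers to the other processes at the start of each forward pass, which keeps the replicas' running stats identical but means they reflect only rank 0's data. The fix is nn.SyncBatchNorm.convert_sync_batchnorm(model) before wrapping in DDP.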
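The conversion order is worth spelling out: convert first, move to the device, then wrap. The sketch assumes torch.distributed.init_process_group(...) has already been called and that each process drives one GPU. The function name wrap_for_ddp and its device_id argument are hypothetical. Because the conversion itself needs no process group, it can be checked on a single machine.

```python
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(model: nn.Module, device_id: int) -> nn.Module:
    """Hypothetical helper: SyncBatchNorm conversion, then DDP (one GPU per process)."""
    if not (dist.is_available() and dist.is_initialized()):
        raise RuntimeError("call torch.distributed.init_process_group first")
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # must happen before DDP wraps it
    model = model.to(f"cuda:{device_id}")
    return DDP(model, device_ids=[device_id])


# The conversion alone works without a process group, so it can be checked locally:
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
print(type(net[1]).__name__)   # SyncBatchNorm
```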
Mixed Precision Training: When using torch.cuda.amp, BatchNorm layers must remain in float32. Autocast does not cast module parameters, so BatchNorm's gamma, beta, and running statistics stay in float32 even when its input is float16, and this case is handled. However, when you call model.half() or use custom CUDA kernels, BatchNorm in half precision may become numerically unstable because epsilon becomes proportionally larger relative to small-magnitude values.
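A quick audit can show whether a model's BatchNorm layers have been pushed out of float32. The helper name low_precision_batchnorm is hypothetical. It flags any BatchNorm layer whose floating-point parameters or buffers are not float32, which is exactly what model.half() produces. The autocast path, for example torch.autocast(device_type="cuda", dtype=torch.float16), leaves them untouched.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def low_precision_batchnorm(model: nn.Module) -> list:
    """Hypothetical helper: names of BatchNorm layers with non-float32 params or buffers."""
    flagged = []
    for name, module in model.named_modules():
        if isinstance(module, BN_TYPES):
            tensors = list(module.parameters(recurse=False)) + list(module.buffers(recurse=False))
            # num_batches_tracked is an integer buffer, so only floating-point tensors count.
            if any(t.is_floating_point() and t.dtype != torch.float32 for t in tensors):
                flagged.append(name)
    return flagged


net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
print(low_precision_batchnorm(net))          # []
print(low_precision_batchnorm(net.half()))   # ['1']: half() casts BatchNorm stats too
```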
ONNX Export: When exporting a model with BatchNorm to ONNX, the normalised output must use the running stats. Ensure that your exporter (e.g., torch.onnx.export) runs the model in eval mode and that your trace uses a representative input shape. If the deployed batch size varies, declare the batch dimension as dynamic (for example via dynamic_axes) rather than relying on the traced size.
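A cheap sanity check before calling torch.onnx.export is to confirm that each sample's output is the same whether it runs alone or inside a batch. This holds in eval mode, where BatchNorm uses running stats, and fails in train mode. The helper name check_batch_independence and its tolerance are assumptions for illustration.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def check_batch_independence(model: nn.Module, batch: torch.Tensor, atol: float = 1e-5) -> bool:
    """Hypothetical pre-export check: True if per-sample outputs don't depend on batch mates."""
    full = model(batch)
    singles = torch.cat([model(batch[i : i + 1]) for i in range(batch.size(0))])
    return torch.allclose(full, singles, atol=atol)


torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
x = torch.randn(4, 3, 8, 8)
print(check_batch_independence(net.train(), x))   # False: batch statistics leak across samples
print(check_batch_independence(net.eval(), x))    # True: running statistics only
# Only export after net.eval(), e.g. torch.onnx.export(net, (x[:1],), "model.onnx", ...)
```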
- Without SyncBatchNorm, each GPU runs its own 'consensus' on a subset of data, leading to divergent local models.
- SyncBatchNorm adds a barrier and all-reduce of mean and variance — negligible cost for large batches, but for small per-GPU batches the overhead dominates.
- The all-reduce communicates only two scalars per channel, so it's cheap — typically <5% overhead for ResNet-50 with 8 GPUs.
If you call half() on your model, expect degraded accuracy. Avoid model.half() on a model with BatchNorm layers; use AMP's autocast instead of calling half() on the model.
The Silent NaN Nightmare: Batch Size 1 Inference with model.train()
The team believed model.eval() was a performance optimisation, not a correctness requirement. They left the model in training mode during inference, reasoning that 'it shouldn't change the math much'. The fix: call model.eval() and wrap inference in torch.no_grad(). Use a context manager to guarantee the mode is reverted even if inference raises an exception.
- model.eval() is not optional for BatchNorm; it's a correctness requirement. Always switch to eval mode before inference.
- Test inference with batch size 1 in every CI pipeline to catch this class of bug early.
- Educate the team: training and eval are fundamentally different computational graphs for BatchNorm.
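The context-manager fix and the batch-size-1 CI check can be combined. Note that torch.no_grad() and torch.inference_mode() only disable autograd; neither switches BatchNorm to running statistics, so model.eval() is still required. The sketch assumes a pytest-style test suite, and eval_mode and build_model are hypothetical names; build_model stands in for your real model factory.

```python
import torch
import torch.nn as nn
from contextlib import contextmanager


@contextmanager
def eval_mode(model: nn.Module):
    """Run inference in eval mode without autograd; restore the previous mode on exit."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield model
    finally:
        model.train(was_training)   # runs even if inference raised


def build_model() -> nn.Module:
    # Hypothetical stand-in for the project's real model factory.
    return nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))


def test_batch_size_one_inference():
    model = build_model().train()           # simulate a model left in training mode
    with eval_mode(model):
        out = model(torch.randn(1, 8))
    assert out.shape == (1, 2)
    assert torch.isfinite(out).all(), "non-finite output at batch size 1"
    assert model.training                   # previous mode restored


def test_mode_restored_after_exception():
    model = build_model().train()
    try:
        with eval_mode(model):
            raise RuntimeError("simulated inference failure")
    except RuntimeError:
        pass
    assert model.training
```

pytest collects both test_ functions automatically; they also run as plain function calls.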
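If inference produces NaN, check whether the model is still in train() mode during inference. Force model.eval() and re-run. To keep the fix in place, wrap inference in a context manager that records the current mode and calls model.eval() on entry, and whose __exit__ restores it with model.train(self.prev).
Key takeaways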
- The train()/eval() switch is not cosmetic: for BatchNorm it decides whether batch or running statistics are used, so every inference path needs a model.eval() call.
Common mistakes to avoid
- Forgetting model.eval() before inference
- Using BatchNorm with tiny batches (< 8)
- Overlooking SyncBatchNorm in DistributedDataParallel
- Incorrect layer ordering (e.g., ReLU before BN)
- Not updating running stats during fine-tuning: keep BatchNorm layers in train() mode during fine-tuning so running stats update.
Interview Questions on This Topic
Explain how Batch Normalisation helps mitigate the 'Dying ReLU' problem.