Senior 10 min · March 06, 2026

RNN Vanishing Gradients: BLEU Drop 42→29 on Long Inputs

BLEU dropped 42→29 on 25+ token sentences.

Naren · Founder
Plain-English first. Then code. Then the interview question.
Quick Answer
  • Vanilla RNNs share the same weight matrix at every timestep, causing gradients to either vanish or explode over long sequences
  • LSTMs solve this with three gates: forget, input, and output — each controlling what enters, stays, or leaves the cell state
  • The cell state acts as a gradient highway: information flows unchanged unless the forget gate explicitly modifies it
  • In production, padded batches combined with unmasked loss functions are a top source of silent training failures
  • LSTM inference is ~4x slower than an equivalent feedforward layer — don't use it when a transformer or 1D-CNN suffices
Plain-English First

Imagine you're reading a mystery novel. Every time you turn a page, you remember clues from earlier chapters — you don't forget the butler's suspicious alibi just because you're now on chapter 12. A standard neural network is like someone who can only read one sentence at a time with no memory of the last one. An RNN gives the network a notepad to jot things down as it reads. An LSTM gives it a smarter notepad with a built-in eraser, a highlighter, and a sticky note — so it remembers only what actually matters, for as long as it actually matters.

Language translation, real-time speech recognition, stock price forecasting, music generation — every one of these tasks shares a property that standard feedforward networks fundamentally cannot handle: the output depends not just on the current input, but on a sequence of past inputs. When Google Translate converts a sentence from German to English, word order shifts dramatically between languages, so the model must carry meaning across dozens of tokens simultaneously. That is a sequence problem, and it is everywhere in production ML.

The feedforward network processes each input in isolation. Feed it the word 'bank' with no context and it cannot tell you whether the answer is a financial institution or a river bank. Recurrent Neural Networks solve this by threading a hidden state through time — each timestep reads the current input and the previous hidden state together, creating a rolling summary of everything seen so far. The problem is that 'rolling summary' degrades fast. After thirty timesteps, the gradient signal needed to teach the network about something that happened at timestep one has been multiplied by a weight matrix thirty times over, and it either vanishes to zero or explodes to infinity. Long Short-Term Memory networks, introduced by Hochreiter and Schmidhuber in 1997, are the engineering answer to that mathematical catastrophe.

By the end of this article you'll understand exactly why vanilla RNNs fail on long sequences, how LSTM gates control information flow at the mathematical level, how to implement and train both in PyTorch with production-quality code, and the real mistakes that silently destroy model performance in live systems. You'll also walk away with the precise vocabulary to answer LSTM questions in a senior ML engineering interview.

The Recurrent Cell: How RNNs Process Sequences

A Recurrent Neural Network processes a sequence of inputs by maintaining a hidden state vector that is updated at each timestep. At time t, the hidden state h_t is computed as h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh). The same weight matrices W_ih and W_hh are reused at every step — that's the recurrence. This parameter sharing is both the power and the curse: it lets the model generalise across varying sequence lengths, but it also means the gradient involves repeated multiplication by the same matrix, causing it to grow or shrink exponentially with sequence length.
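
The update rule above can be checked directly against PyTorch. The sketch below is illustrative only: it recomputes h_t by hand from the parameters of an nn.RNNCell and compares the result with the library's own output. The tensor sizes and the helper name manual_rnn_step are assumptions made for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.RNNCell(input_size=4, hidden_size=3, nonlinearity="tanh")

def manual_rnn_step(x_t, h_prev, cell):
    """h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh), written out by hand."""
    return torch.tanh(
        x_t @ cell.weight_ih.T + cell.bias_ih
        + h_prev @ cell.weight_hh.T + cell.bias_hh
    )

x = torch.randn(5, 2, 4)      # (seq_len, batch, input_size)
h_manual = torch.zeros(2, 3)  # (batch, hidden_size)
h_ref = torch.zeros(2, 3)

with torch.no_grad():
    for x_t in x:  # the same weights are reused at every timestep
        h_manual = manual_rnn_step(x_t, h_manual, cell)
        h_ref = cell(x_t, h_ref)

max_diff = (h_manual - h_ref).abs().max().item()
print(f"max difference vs nn.RNNCell: {max_diff:.2e}")
```

The loop makes the parameter sharing visible: nothing inside it depends on the timestep index, only on the previous hidden state and the current input.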

Production engineers often forget that the hidden state dimension must be large enough to capture the information bottleneck. If your sequence contains 200 words of financial news, a hidden size of 32 will force severe compression — you'll lose sentiment cues from the third sentence. Rule of thumb: hidden size >= vocabulary size * 0.05 for language tasks, but measure empirically via ablation.

In PyTorch, a single-layer RNN is trivial to instantiate, but the default initialisation for the recurrent weight matrix uses uniform distribution in [-1/sqrt(hidden), 1/sqrt(hidden)]. That range is too narrow to preserve gradient magnitude beyond 10 steps. Always override with orthogonal initialisation for the recurrent kernel.

io/thecodeforge/rnn_basics.py (Python)
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            nonlinearity='tanh'
        )
        # Override default uniform initialisation with orthogonal for gradient flow
        for name, param in self.rnn.named_parameters():
            if 'weight_hh' in name:
                nn.init.orthogonal_(param, gain=1.0)
            elif 'bias' in name:
                nn.init.zeros_(param)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        # x shape: (batch, seq_len)
        emb = self.embedding(x)  # (batch, seq_len, embed_dim)
        out, hidden = self.rnn(emb, hidden)  # out: (batch, seq_len, hidden)
        logits = self.fc(out)    # (batch, seq_len, vocab_size)
        return logits, hidden
Critical: Gradient Flow in Vanilla RNNs
Even with orthogonal initialisation, a vanilla RNN with tanh activation will suffer from vanishing gradients beyond ~30 timesteps. The derivative of tanh is at most 1.0, and repeated multiplication by a weight matrix with singular values <1.0 drives gradients to zero. For any sequence longer than 100 steps, use an LSTM or GRU — the gating mechanisms are not optional.
Production Insight
Many teams use an RNN without gradient clipping because 'the loss seems fine'.
The loss reflects average performance across all timesteps; a few exploding gradients can corrupt the entire parameter update.
If you see loss spikes > 10x the running average, check gradient norm first — not learning rate.
Without clipping, one bad batch can reset weeks of training.
Key Takeaway
Vanilla RNNs share weights across time, causing gradient decay.
Always initialise recurrent weights orthogonally.
For any production use case, prefer LSTM or GRU over vanilla RNN.
Choose the Right Recurrent Architecture
If: Sequence length < 20, small data (< 10k samples)
Use: Vanilla RNN may work if combined with gradient clipping and orthogonal init.
If: Sequence length 20–200, need to remember long-term context
Use: LSTM with forget gate bias = 1.0. Add gradient clipping.
If: Sequence length > 200, limited compute budget
Use: GRU (fewer parameters, similar performance) + truncated BPTT.
If: Sequence length > 500, data is abundant
Use: A Transformer encoder instead of an RNN. RNNs cannot compete.

The Vanishing & Exploding Gradient Problem — Why It's the Core Issue

Backpropagation through time (BPTT) unrolls the recurrent computation into a deep feedforward network with T layers, each sharing the same weight matrix W_hh. Propagating the gradient of the loss at timestep T back to timestep 1 multiplies T-1 Jacobians of the form diag(1 - h_t^2) W_hh, so W_hh enters the product T-1 times. Its singular values, together with the tanh derivative, determine whether this product shrinks toward zero (vanishing) or grows without bound (exploding).

In practice, tanh squashes outputs to (-1, 1) and its derivative never exceeds 1, so if the largest singular value of W_hh is below 1 the gradient is guaranteed to decay exponentially with distance. Exploding requires the largest singular value to exceed 1, which is common once the recurrent weights grow during training. The fix for exploding is gradient clipping: cap the gradient norm at some threshold. The fix for vanishing is architectural: change the recurrence itself so gradients have a path that is not repeatedly multiplied by a weight matrix.
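
The decay is easy to observe empirically. The following is a minimal sketch, assuming a freshly initialised nn.RNN with PyTorch's default weights and random inputs. It measures how strongly the last hidden state still depends on the very first input as the sequence gets longer. The helper name first_step_grad_norm and all sizes are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

def first_step_grad_norm(seq_len, hidden_size=32, seed=0):
    """Norm of d(last hidden state)/d(first input) for an untrained vanilla RNN."""
    torch.manual_seed(seed)  # same initial weights for every sequence length
    rnn = nn.RNN(input_size=8, hidden_size=hidden_size, batch_first=True)
    x = torch.randn(1, seq_len, 8, requires_grad=True)
    out, _ = rnn(x)
    out[:, -1].sum().backward()       # backpropagate from the final timestep only
    return x.grad[0, 0].norm().item() # gradient that reaches timestep 0

norms = {T: first_step_grad_norm(T) for T in (5, 20, 80)}
for T, g in norms.items():
    print(f"seq_len={T:3d}  grad norm at t=0: {g:.3e}")
```

With default initialisation the printed norms should fall by many orders of magnitude as the sequence length grows, which is the vanishing-gradient effect described above. A trained network, or one with a different initialisation, will show different numbers.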

LSTMs address vanishing by introducing a cell state that is updated additively: C_t = f_t * C_{t-1} + i_t * C_tilde. Along this path, the gradient from C_t back to C_{t-1} is scaled only by the forget gate f_t, a learned sigmoid value between 0 and 1 applied element-wise, and never by a full weight matrix. When the forget gate stays close to 1, the gradient passes through almost unchanged, which breaks the repeated matrix multiplication chain.

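Production mistake: engineers call the clipping function in the wrong place, for example before loss.backward(), when the gradients for the current batch do not exist yet, or after optimizer.step(), when the update has already been applied. The right order: loss.backward(), then clip, then optimizer.step(). PyTorch's autograd does not clip automatically.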

io/thecodeforge/gradient_clip_example.py (Python)
import torch
import torch.nn as nn

# SimpleRNN is the class defined in rnn_basics.py above;
# dataloader is assumed to yield (x, y) LongTensor pairs of shape (batch, seq_len).
model = SimpleRNN(vocab_size=10000, embed_dim=256, hidden_size=256)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Training loop with gradient clipping
for batch in dataloader:
    x, y = batch  # x: (batch, seq_len), y: (batch, seq_len)
    logits, _ = model(x)
    loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

    optimizer.zero_grad()
    loss.backward()

    # MUST clip after backward, before step.
    # clip_grad_norm_ returns the total gradient norm measured BEFORE clipping,
    # so it doubles as the debugging signal.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    if total_norm > 100:
        print(f"Warning: high pre-clip gradient norm {total_norm.item():.2f}")

    optimizer.step()
Why LSTMs Beat Vanishing Gradients
  • The cell state C_t is updated by an element-wise multiplication with the forget gate and an element-wise addition of the gated candidate. Both gates are computed from the previous hidden state h_{t-1} and the current input x_t.
  • Along the cell-state path, the derivative of C_t with respect to C_{t-1} is the forget gate value, which is between 0 and 1. It is one scalar per element, not a matrix product.
  • Because the cell state is carried forward by addition rather than by multiplication with a weight matrix, the gradient can remain stable for hundreds of timesteps as long as the forget gate stays near 1.
  • Compare to vanilla RNN: the gradient through h_{t-1} involves the full Jacobian, which contains W_hh^T at every step. That is where the exponential decay or explosion comes from.
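
The claim about the cell-state path can be verified with autograd. This is a simplified sketch, not a full LSTM: the gate values are treated as fixed constants. In a real cell they also depend on h_{t-1}, which adds further gradient paths. All tensor names are illustrative.

```python
import torch

torch.manual_seed(0)
hidden = 4

# Gate activations held fixed for the purpose of this check (an assumption).
f = torch.sigmoid(torch.randn(hidden))   # forget gate, values in (0, 1)
i = torch.sigmoid(torch.randn(hidden))   # input gate
c_hat = torch.tanh(torch.randn(hidden))  # candidate cell state
c_prev = torch.randn(hidden)

def cell_update(c_previous):
    """C_t = f * C_{t-1} + i * C_tilde (element-wise)."""
    return f * c_previous + i * c_hat

# Jacobian of C_t with respect to C_{t-1}: shape (hidden, hidden)
jac = torch.autograd.functional.jacobian(cell_update, c_prev)
print(jac)
```

The Jacobian is diagonal with the forget gate values on the diagonal, so each cell-state element is scaled independently and no weight matrix appears on this path.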
Production Insight
Gradient clipping is not a theoretical nicety; it is a production necessity.
Without it, a single batch with an outlier sequence can push the parameters into a region where the loss explodes and never recovers.
In a 2023 incident at a major NLP company, the training of a 12-layer LSTM kept crashing with NaN loss — the root cause was a gradient update of norm 1e8 that escaped clipping because the threshold was set to 10.
Set max_norm between 1.0 and 5.0; monitor the clipping frequency; if more than 20% of batches hit the threshold, reduce learning rate.
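
One way to monitor clipping frequency is to use the value returned by clip_grad_norm_, which is the total gradient norm before clipping. The sketch below uses a toy linear model and random data as stand-ins for a real recurrent model and dataloader. The 20% threshold comes from the advice above; every other number is an arbitrary assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(16, 1)  # toy stand-in for a recurrent model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
max_norm = 1.0
clipped_steps, total_steps = 0, 50

for _ in range(total_steps):
    x, y = torch.randn(32, 16), torch.randn(32, 1)  # random stand-in batch
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    # Returns the norm measured before clipping; gradients are rescaled in place.
    pre_clip_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    clipped_steps += int(pre_clip_norm.item() > max_norm)
    optimizer.step()

clip_rate = clipped_steps / total_steps
if clip_rate > 0.2:
    print(f"{clip_rate:.0%} of batches were clipped; consider lowering the learning rate")
```

In a real training job the same counter would typically be reset per epoch and logged alongside the loss.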
Key Takeaway
Vanishing: gradients die from repeated matrix multiplication.
LSTM fixes this with additive cell state updates.
Exploding: fix with gradient clipping after backward, before step.

LSTM Gate Mechanics: Forget, Input, Output, and Cell State

An LSTM cell has four neural network layers that control the flow of information. The three gates — forget, input, output — each produce values between 0 and 1 via sigmoid activation. The cell state update is:

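f_t = sigmoid(W_f [h_{t-1}, x_t] + b_f)
i_t = sigmoid(W_i [h_{t-1}, x_t] + b_i)
o_t = sigmoid(W_o [h_{t-1}, x_t] + b_o)
C_tilde = tanh(W_C [h_{t-1}, x_t] + b_C)
C_t = f_t * C_{t-1} + i_t * C_tilde
h_t = o_t * tanh(C_t)

Here [h_{t-1}, x_t] is the concatenation of the previous hidden state and the current input, and * denotes element-wise multiplication.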

The forget gate determines what fraction of the previous cell state to keep. The input gate decides how much of the new candidate to add. The output gate controls what part of the cell state is exposed to the next layer.

A common practice is to initialise the forget gate bias to 1.0 (or a larger positive number) so that the network starts in a state of 'remember almost everything'. This keeps the cell state from being erased early in training, before the gates have learned anything useful. In production, if you see the model losing long-range dependencies, check the forget gate biases: if they've drifted below 0, reset them to 1 and freeze them for the first 10 epochs.

io/thecodeforge/lstm_cell.py (Python)
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Combined weight matrices for the three gates and the candidate
        # Input-to-hidden: (4 * hidden_size) x input_size
        # Hidden-to-hidden: (4 * hidden_size) x hidden_size
        self.W_ih = nn.Linear(input_size, 4 * hidden_size, bias=True)
        self.W_hh = nn.Linear(hidden_size, 4 * hidden_size, bias=False)

        # Initialise forget gate bias to 1.0.
        # The forget gate is the FIRST chunk in forward(), so its bias
        # occupies indices [0, hidden_size). (nn.LSTM uses a different order.)
        with torch.no_grad():
            self.W_ih.bias[:hidden_size].fill_(1.0)

    def forward(self, x, state):
        h_prev, c_prev = state
        # Pre-activations for all four blocks at once
        gates = self.W_ih(x) + self.W_hh(h_prev)  # (batch, 4*hidden)
        # Split into forget, input, candidate, output
        f_gate, i_gate, c_tilde, o_gate = gates.chunk(4, dim=-1)

        f = torch.sigmoid(f_gate)
        i = torch.sigmoid(i_gate)
        c_hat = torch.tanh(c_tilde)
        o = torch.sigmoid(o_gate)

        c = f * c_prev + i * c_hat   # additive cell state update
        h = o * torch.tanh(c)        # exposed hidden state
        return h, (h, c)
Forget Gate Bias Initialisation
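Set the forget gate bias to 1.0-2.0 at initialisation. This prevents the network from forgetting everything at the start of training, which is a common failure mode. PyTorch's nn.LSTM does not expose a constructor argument for this; its biases use the same default uniform initialisation as the weights. Set the forget gate slice of bias_ih_l{k} and bias_hh_l{k} yourself. The gates are stored in (input, forget, cell, output) order, so the forget gate occupies indices hidden_size to 2*hidden_size, and because the two bias vectors are summed, setting both to 1.0 gives an effective bias of 2.0.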
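
For PyTorch's built-in nn.LSTM, the forget gate bias has to be set by hand. The sketch below relies on the documented (input, forget, cell, output) ordering of the stacked bias vectors. The helper name init_forget_bias is hypothetical, and zeroing the hidden-side bias slice is a design choice made here so the effective bias equals the requested value.

```python
import torch
import torch.nn as nn

def init_forget_bias(lstm: nn.LSTM, value: float = 1.0) -> None:
    """Set the effective forget-gate bias of every layer and direction to `value`."""
    h = lstm.hidden_size
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            if name.startswith("bias_ih"):
                param[h:2 * h].fill_(value)  # forget-gate slice
            elif name.startswith("bias_hh"):
                param[h:2 * h].zero_()       # the two biases are summed, so zero this one

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
init_forget_bias(lstm, 1.0)
print(lstm.bias_ih_l0[16:32])
```

Matching on the name prefix also covers the `_reverse` parameters of a bidirectional LSTM.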
Production Insight
The output gate controls how much of the cell state is passed to the next layer.
If the output gate saturates near 0, the hidden state becomes near-zero and the LSTM effectively stops processing new input.
This symptom appears as the model 'ignoring' later parts of the sequence.
Monitor the mean value of the output gate activations across a validation set. If it's below 0.1, increase the output gate bias or reduce the initial forget gate bias.
Key Takeaway
Forget gate: keep old cell state.
Input gate: add new candidate.
Output gate: expose cell state.
Forget gate bias = 1.0 — engineer this, don't leave it to chance.

Visual Gate Logic Diagrams: Forget, Input, Output Flows

The three gates of an LSTM can be understood visually as decision points in the information flow. The diagram below shows the full LSTM cell with the forget gate (red), input gate (green), output gate (blue), and the cell state highway. The forget gate decides what to keep from the previous cell state, the input gate decides what new information to add, and the output gate decides what part of the cell state to output as the hidden state. Each gate is a sigmoid layer that outputs values between 0 (block) and 1 (pass through), followed by element-wise multiplication with the respective signal.

Forget Gate Flow: The previous hidden state and current input are concatenated and passed through a sigmoid layer. The output is multiplied element-wise with the previous cell state. If the gate outputs 0, the information is forgotten; if 1, it is fully retained.

Input Gate Flow: The same concatenated input passes through a sigmoid (input gate) and a tanh (candidate cell state). The tanh outputs values between -1 and 1. The input gate output is multiplied by the candidate, and the result is added to the cell state.

Output Gate Flow: The new cell state is passed through tanh and then multiplied by the output gate's sigmoid output. This produces the new hidden state, which also goes to the next timestep and the output layer.

Reading the Diagram
Each gate uses a sigmoid to produce a value between 0 and 1. The forget gate scales the old cell state; the input gate determines how much new information to add; the output gate controls which part of the cell state is exposed. The cell state itself flows mostly unchanged, only modified by the addition and multiplication from the gates.
Production Insight
Visualising gate activations during inference can reveal silent failures.
If the forget gate outputs are consistently close to 1, the model is not learning to forget — it may be overfitting to short sequences.
If the input gate is near 0 for all timesteps, the model isn't incorporating new information.
Plot histograms of gate activations on a validation batch to detect these issues early.
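
nn.LSTM does not return gate activations, so inspecting them means recomputing them from the layer's weights. The sketch below does this for a single nn.LSTMCell step and checks the result against the library output. The helper name lstm_gate_activations is an assumption for this example, and the same idea would need to be applied per timestep and per layer for a full model.

```python
import torch
import torch.nn as nn

def lstm_gate_activations(cell: nn.LSTMCell, x, h, c):
    """Recompute one LSTM step by hand and return the gate activations."""
    pre = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = pre.chunk(4, dim=-1)  # PyTorch order: input, forget, cell, output
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_new = f * c + i * g
    h_new = o * torch.tanh(c_new)
    return {"input": i, "forget": f, "output": o}, (h_new, c_new)

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=8, hidden_size=16)
x = torch.randn(4, 8)
h = torch.zeros(4, 16)
c = torch.zeros(4, 16)

with torch.no_grad():
    gates, (h_new, c_new) = lstm_gate_activations(cell, x, h, c)
    h_ref, c_ref = cell(x, (h, c))  # library result for comparison

gate_means = {name: values.mean().item() for name, values in gates.items()}
print(gate_means)
```

The per-gate tensors can be passed to any histogram tool. The means alone are often enough to spot a gate that has saturated near 0 or 1.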
Key Takeaway
Gates are the LSTM's decision-making layers.
Visual diagrams help internalise the flow of information from cell state to hidden state.

Sequence Architecture Comparison: RNN, LSTM, GRU, Transformer

Each architecture for sequence modelling has a different trade-off between memory, speed, and gradient stability. The table below compares the four major architectures used in production systems today. Note that the Transformer replaces recurrence with self-attention, allowing parallel computation over all timesteps, but at quadratic cost in sequence length.

| Property | Vanilla RNN | LSTM | GRU | Transformer |
| --- | --- | --- | --- | --- |
| Gating mechanism | None (single tanh) | Forget, Input, Output gates | Update, Reset gates | Self-attention + feedforward |
| Cell state | No | Yes | No (hidden state only) | No (positional encodings) |
| Parameter count (hidden=256) | ~1.3M | ~2.0M | ~1.6M | ~2.5M (4 heads, 256 d_model) |
| Training speed (relative) | 1.5x faster than LSTM | 1.0x baseline | 1.2x faster | 0.8x (slower due to attention) |
| Long sequence performance (>100 steps) | Poor | Excellent | Good | Excellent (but O(n²) memory) |
| Gradient stability | Very poor | Good (with clipping) | Good | Excellent (no recurrence) |
| Common production use | Almost never | Time series, NLP, seq2seq | Translation, music gen | Language models, translation |
| PyTorch class | nn.RNN | nn.LSTM | nn.GRU | nn.TransformerEncoder / Decoder |
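
Relative parameter counts of the recurrent layers can be checked directly. Absolute totals depend on what else the model contains, such as embedding and output layers, so this sketch only compares one recurrent layer of each kind with input size and hidden size both set to 256. The helper name count_params is an assumption.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of trainable and non-trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

size = 256
counts = {
    "rnn": count_params(nn.RNN(size, size)),    # one block of weights
    "lstm": count_params(nn.LSTM(size, size)),  # four blocks: i, f, g, o
    "gru": count_params(nn.GRU(size, size)),    # three blocks: r, z, n
}
for name, n in counts.items():
    print(f"{name:5s} {n:>9,d} parameters")
```

For a single layer, the LSTM carries four times and the GRU three times the parameters of the vanilla RNN, which is where the GRU's speed advantage comes from.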

The Transformer's key advantage is that it has no recurrence: self-attention connects any two positions directly, so gradients do not pass through a chain of repeated matrix multiplications across timesteps and all positions can be processed in parallel. However, the quadratic cost of self-attention makes it impractical for very long sequences without approximations (e.g., Longformer, Performer). For sequences under 512 tokens, the Transformer is now the default choice in NLP. For time series or streaming scenarios where recurrence is natural, LSTM and GRU remain competitive.

Production Rule of Thumb: If your sequence length < 100 and you need real-time inference, use GRU. If accuracy is paramount and you have GPU memory for attention, use a Transformer. If you are already using LSTM and it works, there is no urgent need to switch — but monitor per-length performance.

When to Stick with LSTM vs Migrate to Transformer
If your production system already has a tuned LSTM pipeline (data preprocessing, inference servers, monitoring), the cost of migrating to Transformers may not be justified unless you need a significant accuracy lift. Many teams maintain both: an LSTM for low-latency streaming and a Transformer for high-accuracy offline processing.
Production Insight
The choice between RNNs and Transformers often comes down to latency and memory budgets.
In a real-time speech recognition system, an LSTM can output tokens after each input frame; a Transformer encoder with full (non-causal) self-attention must wait for the complete utterance.
Always benchmark both architectures on your actual sequence length distribution before committing to a deployment.
Key Takeaway
No single architecture dominates all production scenarios.
Match the architecture to your sequence length, latency, and memory constraints.

Peephole Connections and LSTM Variants

Peephole connections are an LSTM variant introduced by Gers and Schmidhuber in 2000, a few years after the original Hochreiter and Schmidhuber design: the gates receive the cell state directly, not just the hidden state. The forget gate becomes f_t = sigmoid(W_f [C_{t-1}, h_{t-1}, x_t] + b_f). This allows the network to 'look into' the cell state when deciding what to forget. In practice, peephole connections add parameters and often give only marginal gains. They are rarely used in modern implementations because the standard LSTM already performs well for most tasks.

Other variants include the GRU (Gated Recurrent Unit) which merges the forget and input gates into a single update gate and removes the separate cell state. GRU has fewer parameters and trains faster but can struggle on very long sequences. The bidirectional LSTM (BiLSTM) runs two LSTMs in opposite directions and concatenates their outputs — essential for tasks where context from both past and future matters, like named entity recognition.

Production tip: If you need speed and your sequence length is moderate (<100), use GRU instead of LSTM. At batch size 64 with sequence length 50, GRU is ~20% faster in training. If you need maximum accuracy and have long sequences, use a standard LSTM with gradient clipping, and add peephole connections only if an ablation on your own data shows a gain.

A common mistake is stacking too many LSTM layers. Three layers are often enough; beyond that, the gradient decays through the vertical dimension, not just the temporal. Use residual connections between layers to help.
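
Residual connections between stacked layers can be sketched as follows. This is one possible arrangement, assuming every layer keeps the same feature size so the input can be added to the output. The class name ResidualLSTMStack is hypothetical, and nn.LSTM's own num_layers argument does not add such skip connections.

```python
import torch
import torch.nn as nn

class ResidualLSTMStack(nn.Module):
    """Stack of single-layer LSTMs with a skip connection around each layer."""

    def __init__(self, size: int, num_layers: int = 3):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.LSTM(size, size, batch_first=True) for _ in range(num_layers)]
        )

    def forward(self, x):
        # x: (batch, seq_len, size)
        for lstm in self.layers:
            out, _ = lstm(x)
            x = x + out  # residual path lets gradients bypass the layer
        return x

torch.manual_seed(0)
stack = ResidualLSTMStack(size=32, num_layers=4)
x = torch.randn(2, 10, 32)
y = stack(x)
print(y.shape)
```

Because each layer only adds a correction to its input, the gradient reaching the lower layers always includes an identity term, which counters the vertical decay mentioned above.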

io/thecodeforge/lstm_variants.py (Python)
import torch.nn as nn

# Standard LSTM
lstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=2, batch_first=True)

# Bidirectional LSTM
bilstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=2,
                 bidirectional=True, batch_first=True)
# Output feature size is 512: 256 per direction, concatenated

# GRU (three gate blocks instead of four, so roughly 75% of the LSTM's parameters)
gru = nn.GRU(input_size=256, hidden_size=256, num_layers=2, batch_first=True)

# For peephole connections you need a custom cell:
# PyTorch's built-in nn.LSTM does not support them
When to Use BiLSTM vs LSTM
BiLSTM: use for any NLP task where you have access to the full sequence at train and inference time (e.g., text classification, NER). Do NOT use for real-time generation (e.g., language modelling, online transcription) because the reversed pass would require the full sequence before it can produce any output.
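
For full-context tasks such as classification, a common pattern is to build one vector per sequence from the final states of both directions. The sketch below shows the output shapes of a bidirectional nn.LSTM and one way to extract that vector. The sizes are arbitrary and the variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, hidden, batch = 2, 64, 4
bilstm = nn.LSTM(input_size=32, hidden_size=hidden, num_layers=num_layers,
                 bidirectional=True, batch_first=True)

x = torch.randn(batch, 10, 32)
out, (h_n, c_n) = bilstm(x)
# out: (batch, seq_len, 2 * hidden)   forward and backward features concatenated
# h_n: (num_layers * 2, batch, hidden)

# Separate layers and directions, then keep the top layer only
last_layer = h_n.view(num_layers, 2, batch, hidden)[-1]  # (2, batch, hidden)
sentence_vec = torch.cat([last_layer[0], last_layer[1]], dim=-1)  # (batch, 2 * hidden)
print(out.shape, h_n.shape, sentence_vec.shape)
```

The forward half of sentence_vec matches the output at the last timestep and the backward half matches the output at the first timestep, which is why the backward direction cannot run until the whole sequence is available.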
Production Insight
Deploying a BiLSTM for real-time speech recognition is a mistake you only make once.
The reversed pass doubles latency and memory, and the model cannot emit any token until the utterance ends.
If your system requires low-latency streaming, use a unidirectional LSTM or a causal convolutional model.
On the other hand, for offline processing of logs or documents, BiLSTM is standard.
Key Takeaway
Peephole connections add marginal benefit — default to standard LSTM.
GRU is faster for short sequences.
BiLSTM is for full-context tasks, not streaming.

Scheduled Sampling: Bridging the Train/Inference Gap

Teacher forcing during training feeds the ground-truth token as input to the next timestep. At inference, the model must use its own predictions. This mismatch, called exposure bias, causes the model to never learn to correct its own errors. Scheduled sampling gradually reduces the probability of using ground-truth tokens during training, forcing the model to learn from its own outputs.

A simple schedule: start with a high probability (e.g., 1.0) of teacher forcing, and decay it linearly to a floor over training. For example, decay from 1.0 to 0.2 over 100k steps, then hold at 0.2. The code below implements scheduled sampling for a sequence-to-sequence model using a step-based schedule.

Production note: Scheduled sampling can destabilise training if the schedule is too aggressive. Start with a small decay per step (e.g., 1e-5) and monitor the loss. If the loss spikes, slow down the decay or use a curriculum where short sequences are teacher-forced longer.

io/thecodeforge/scheduled_sampling.py (Python)
import torch
import torch.nn.functional as F

def scheduled_sampling(model, input_tokens, target_tokens, step, total_steps,
                       teacher_forcing_start=1.0, teacher_forcing_end=0.2):
    """
    Apply scheduled sampling during training.

    Args:
        model: nn.Module that takes (input_tokens, prev_token) and returns logits.
        input_tokens: (batch, seq_len) ground-truth input sequence (e.g., source).
        target_tokens: (batch, seq_len) ground-truth target sequence.
        step: current training step.
        total_steps: number of steps over which the probability decays.
    Returns:
        loss: scalar loss for this batch, averaged over tokens.
    """
    batch_size, seq_len = target_tokens.shape

    # Teacher forcing probability: linear decay from start to end, then held
    ratio = min(1.0, step / total_steps)  # 0 at start, 1 at end
    teacher_prob = teacher_forcing_start + (teacher_forcing_end - teacher_forcing_start) * ratio

    # Start with first token (target_tokens[:, 0] is start token)
    prev_token = target_tokens[:, 0]
    total_loss = 0.0

    for t in range(1, seq_len):
        logits = model(input_tokens, prev_token)  # (batch, vocab_size)
        total_loss = total_loss + F.cross_entropy(logits, target_tokens[:, t], reduction='sum')

        # One coin flip per timestep, shared by the whole batch
        if torch.rand(1).item() < teacher_prob:
            prev_token = target_tokens[:, t]
        else:
            # Greedy decode; argmax is not differentiable, so no gradient
            # flows back through the sampled token
            prev_token = torch.argmax(logits, dim=-1)

    # Average over every predicted token (batch and time)
    return total_loss / (batch_size * (seq_len - 1))

# Usage in training loop:
# for step, (src, tgt) in enumerate(dataloader):
#     optimizer.zero_grad()
#     loss = scheduled_sampling(model, src, tgt, step, total_steps=100000)
#     loss.backward()
#     optimizer.step()
Scheduled Sampling Can Increase Training Instability
If the teacher forcing probability drops too quickly, the model may diverge because it has never learned to recover from bad predictions. Always validate your schedule on a held-out set. A common safe choice is to decay from 1.0 to 0.5 over the first half of training, then keep 0.5 for the remainder.
Production Insight
Scheduled sampling is a powerful tool, but it's not a silver bullet.
In a production translation system, we found that scheduled sampling introduced more variance in BLEU scores across runs.
The alternative — training with a small amount of noise (label smoothing) — often achieves similar results with less tuning.
Always A/B test scheduled sampling against the baseline before deploying.
Key Takeaway
Scheduled sampling reduces exposure bias by mixing teacher forcing and model-generated inputs during training.
Start with a high teacher probability and decay slowly.

Training and Production Pitfalls: What Actually Breaks Your Model

The most common production failure for sequence models is training/inference mismatch. During training, you often use teacher forcing: the true previous token is fed as input. At inference, you use the model's own predicted token. This discrepancy compounds errors quickly, especially in RNNs where a single wrong token at step 5 can derail the remaining 95 steps.

Mitigations
  • Scheduled sampling: gradually mix teacher forcing with model-generated tokens during training.
  • Curriculum learning: start with short sequences, gradually increase length.
  • Always validate with the same decoding strategy you'll use in production (e.g., beam search for translation).

Another silent killer: padding and masking. If you pad sequences to equal length, the RNN will process padding tokens and produce meaningless outputs. You must apply a mask to the loss so that padding positions contribute zero gradient. PyTorch's nn.utils.rnn.pack_padded_sequence stops the RNN from stepping over padding, but it does nothing to the loss, and many engineers forget to mask the loss as well.

Memory management: LSTMs maintain a state for every sequence in the batch. If you have long sequences and large batches, you can run out of GPU memory. Use gradient checkpointing to trade compute for memory. For very long sequences, truncate backpropagation: after each segment, detach the hidden state graph to avoid storing the full computation graph.
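
Truncated backpropagation through time can be sketched as below: the sequence is processed in fixed-length segments, and the hidden state is detached between segments so that each backward pass only covers one segment. The toy regression task, the segment length of 25 and all sizes are assumptions chosen to keep the example small.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

x = torch.randn(2, 100, 8)  # (batch, long_seq_len, features), random stand-in data
y = torch.randn(2, 100, 1)
segment_len = 25
state = None
losses = []

for start in range(0, x.size(1), segment_len):
    xs = x[:, start:start + segment_len]
    ys = y[:, start:start + segment_len]
    if state is not None:
        # Keep the values, drop the graph: gradients stop at the segment boundary
        state = tuple(s.detach() for s in state)
    out, state = lstm(xs, state)
    loss = F.mse_loss(head(out), ys)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, max_norm=5.0)
    optimizer.step()
    losses.append(loss.item())

print(losses)
```

The forward pass still carries information across segment boundaries through the detached state. Only the gradient is truncated, so dependencies longer than one segment cannot be learned directly.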

io/thecodeforge/loss_masking.py (Python)
import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, targets, lengths):
    """Compute cross-entropy loss ignoring padding positions."""
    # logits: (batch, max_seq_len, vocab_size)
    # targets: (batch, max_seq_len)
    # lengths: (batch,) integer sequence lengths

    max_len = logits.size(1)
    # Create mask of shape (batch, max_len): 1 for valid positions, 0 for padding
    positions = torch.arange(max_len, device=logits.device).unsqueeze(0)
    mask = (positions < lengths.to(logits.device).unsqueeze(1)).float()

    # Compute loss per token; cross_entropy expects the class dim second
    loss = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction='none')
    # Apply mask
    loss = loss * mask
    # Average over valid positions only (the mask sum is the true denominator)
    return loss.sum() / mask.sum().clamp(min=1.0)

# Usage in training loop:
# logits, _ = model(x_padded)  # x_padded: (batch, max_seq_len), padded with the pad id
# loss = masked_cross_entropy(logits, y_padded, lengths)
Padding Tokens Will Destroy Your Gradient
If you compute the loss over the entire padded sequence, padding positions will contribute spurious gradients that teach the model to predict the pad token. Always compute the loss only over valid timesteps. Use pack_padded_sequence before the RNN and pad_packed_sequence after it, then apply the mask as shown.
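
The pack and unpack round trip mentioned above can be sketched as follows. The batch contents and lengths are made up for illustration. enforce_sorted=False is used so the batch does not have to be sorted by length, and total_length keeps the output padded to the original width.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

x = torch.randn(3, 5, 8)           # (batch, max_seq_len, features), already padded
lengths = torch.tensor([5, 3, 2])  # true lengths; must be a CPU tensor

# Pack so the LSTM skips padding timesteps entirely
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack back to a padded tensor; padded positions are filled with zeros
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True,
                                       total_length=x.size(1))
print(out.shape, out_lengths)
```

With packing, h_n holds the state at each sequence's true last timestep, not at the padded end. The loss still needs its own mask, because the zero-filled output positions would otherwise be scored against the padded targets.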
Production Insight
The most expensive bug is the one that doesn't crash — it just quietly degrades model quality.
We saw a team train a sentiment model for two weeks before noticing the loss didn't include a mask.
The model had learned to predict 'padding' as the most common token.
Verify loss masking in your first training script by logging the sum of mask vs the loss denominator.
Key Takeaway
Always mask padding in loss computation.
Use scheduled sampling or curriculum learning to bridge train/inference gap.
Monitor your model's sensitivity to sequence length — it can reveal gradient issues.

Implementing LSTM in Keras/TensorFlow

While PyTorch dominates research, TensorFlow's Keras API is still widely used in production for its ease of deployment via TF Serving and TF Lite. The Keras LSTM layer handles sequence masking, state management, and batching internally. The key differences from PyTorch: Keras uses a Masking layer or mask_zero=True in Embedding, and the LSTM layer accepts return_sequences=True/False.

Below is a minimal LSTM model for language modelling using TensorFlow/Keras. The model is identical in architecture to the PyTorch version but uses Keras' Sequential API. Keras layers infer their input shapes from the data, so no input shape is specified. Keras' Embedding layer masks padding tokens if mask_zero=True is set (token id 0 must then be reserved for padding); the mask propagates through the LSTM and can be used to ignore padding in the loss. The custom training loop below applies a manual mask as well, as a best practice.

io/thecodeforge/lstm_keras.py (Python)
import tensorflow as tf
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.models import Sequential

vocab_size = 10000
embed_dim = 256
hidden_size = 256

model = Sequential([
    # mask_zero=True reserves token id 0 for padding
    Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True),
    LSTM(units=hidden_size, return_sequences=True, dropout=0.2),
    Dense(vocab_size, activation='softmax')
])

# Only needed if you train with model.fit() instead of the custom loop below
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Custom training loop with masking (alternative to built-in masking)
# Assuming padded dataset: (batch, seq_len) int tensors, (batch,) lengths
optimizer = tf.keras.optimizers.Adam()

def train_step(x_batch, y_batch, lengths):
    with tf.GradientTape() as tape:
        probs = model(x_batch, training=True)  # (batch, seq_len, vocab), softmax output
        # Build mask from the actual padded length of this batch
        mask = tf.sequence_mask(lengths, maxlen=tf.shape(x_batch)[1], dtype=tf.float32)
        # Per-token loss, shape (batch, seq_len)
        per_token = tf.keras.losses.sparse_categorical_crossentropy(y_batch, probs)
        # Average over valid (non-padding) tokens only
        loss = tf.reduce_sum(per_token * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)
    grads = tape.gradient(loss, model.trainable_variables)
    # Gradient clipping in TF
    grads, _ = tf.clip_by_global_norm(grads, clip_norm=5.0)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# Example usage:
# for batch in dataset:
#     loss = train_step(batch['source'], batch['target'], batch['length'])

# For inference, use model.predict() or export a SavedModel for serving
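Keras vs PyTorch LSTM: Key Differences
With activation='softmax' on the final Dense layer, the Keras model outputs probabilities, so the loss must use from_logits=False. PyTorch models typically return raw logits and rely on CrossEntropyLoss, which applies log-softmax internally. On the kernel side, the TensorFlow 2 Keras LSTM layer selects the fused cuDNN implementation automatically when it runs on a GPU with cuDNN-compatible settings (default tanh/sigmoid activations, recurrent_dropout=0, unroll=False, use_bias=True); the separate CuDNNLSTM layer belongs to TensorFlow 1.x. PyTorch's nn.LSTM likewise uses cuDNN automatically when the input is on a GPU and the configuration is supported.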
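The probabilities-versus-logits distinction can be checked by hand. The following is a minimal sketch in plain Python, not framework code: the helper names (softmax, ce_from_probs, ce_from_logits) and the toy scores are assumptions made for illustration. It shows that cross-entropy computed from softmax probabilities matches cross-entropy computed directly from logits, and that feeding probabilities into a logits-based loss silently gives a different number.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ce_from_probs(probs, target):
    """Cross-entropy when the model already outputs probabilities (softmax head)."""
    return -math.log(probs[target])

def ce_from_logits(logits, target):
    """Cross-entropy from raw scores via log-sum-exp (the from-logits convention)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

logits = [2.0, -1.0, 0.5]   # toy scores for a 3-token vocabulary
target = 0

loss_a = ce_from_probs(softmax(logits), target)   # softmax head + probability-based loss
loss_b = ce_from_logits(logits, target)           # raw logits + logits-based loss
# Mixing the conventions applies softmax twice and yields a different, wrong loss
loss_wrong = ce_from_logits(softmax(logits), target)

print(loss_a, loss_b, loss_wrong)
```

The first two values agree; the third does not. When porting a model between the two frameworks, this mismatch is worth checking before anything else, because it raises no error.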
Production Insight
Deploying an LSTM in TensorFlow often means converting to TF Lite for mobile or TF Serving for REST APIs.
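The Keras model can be saved with model.save() and loaded for serving: TF 2.x writes a SavedModel directory for a path like 'lstm_model', while Keras 3 expects a .keras file name and provides model.export() for SavedModel output.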
One common gotcha: Keras LSTM layers do not automatically manage state across batches unless you set stateful=True.
For sequence-level classification, either set return_sequences=False to get only the last output, or keep return_sequences=True and apply a global pooling layer over the timesteps.
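The two sequence-level options (last output versus pooling over timesteps) behave differently once padding is involved. A minimal NumPy sketch follows; the array stands in for an LSTM's per-timestep output, and the helper names last_valid_output and masked_mean_pool are hypothetical, not Keras APIs.

```python
import numpy as np

def last_valid_output(outputs, lengths):
    """Pick each sequence's output at its final non-padded timestep."""
    batch_idx = np.arange(outputs.shape[0])
    return outputs[batch_idx, lengths - 1]                 # (batch, hidden)

def masked_mean_pool(outputs, lengths):
    """Average outputs over valid timesteps only, ignoring padding."""
    seq_len = outputs.shape[1]
    mask = (np.arange(seq_len)[None, :] < lengths[:, None]).astype(outputs.dtype)
    summed = (outputs * mask[:, :, None]).sum(axis=1)      # (batch, hidden)
    return summed / lengths[:, None]

# Toy per-timestep outputs: batch=2, seq_len=3, hidden=2.
# The second sequence has one padded step at the end.
outputs = np.array([[[1., 1.], [2., 2.], [3., 3.]],
                    [[4., 4.], [6., 6.], [0., 0.]]])
lengths = np.array([3, 2])

last = last_valid_output(outputs, lengths)    # rows: [3, 3] and [6, 6]
pooled = masked_mean_pool(outputs, lengths)   # rows: [2, 2] and [5, 5]
print(last)
print(pooled)
```

Taking the output at index -1 instead of at lengths - 1 would read the padded step for the second sequence, which is the kind of silent error the masking discussion warns about.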
Key Takeaway
Keras/TensorFlow provides a simpler API for LSTM but requires careful handling of masking and gradient clipping.
For mobile deployment, TF Lite supports LSTM, but its coverage of recurrent ops is limited; consider GRU as a lighter alternative.
● Production incident · POST-MORTEM · severity: high

The Translation Model That Forgot the First Sentence

Symptom
BLEU score dropped from 42 to 29 on sentences longer than 25 tokens. The model output looked grammatically perfect but had no connection to the input's first half.
Assumption
The training data contained shorter examples. The team assumed longer sequences were learned implicitly.
Root cause
The vanilla LSTM architecture had no gradient clipping, and its tanh activations squashed gradients toward zero after ~20 steps of backpropagation. The forget gate bias was initialised to 1.0, which helped, but the cell state still decayed due to the multiplicative forget gates in the deeper layers.
Fix
1. Replace tanh in the output gate with a learnable gating mechanism (add peephole connections to the cell state).
2. Apply gradient clipping at 5.0 norm.
3. Use truncated backpropagation through time (TBPTT) with 50 timesteps.
4. Increase hidden size from 128 to 256.
5. Add a residual connection from the first LSTM layer to the output layer.
Key lesson
  • Vanishing gradients are not just a training issue — they cause inference to degrade silently on long sequences.
  • Always monitor BLEU or accuracy against sequence length after deployment.
  • Gradient clipping is not optional for RNNs in production.
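Monitoring quality against sequence length can be as simple as bucketing a per-sentence score by source length. The sketch below assumes you already have (token count, score) pairs from your own evaluation run; score_by_length is a hypothetical helper, and the numbers are illustrative, not measurements from the incident. BLEU is normally computed at corpus level, so in practice you may prefer to compute it once per bucket rather than average sentence scores.

```python
from collections import defaultdict

def score_by_length(records, bucket_size=10):
    """Average a per-sentence score within source-length buckets.

    records: iterable of (source_token_count, score) pairs.
    Returns {bucket_start: mean_score}, ordered by bucket.
    """
    totals = defaultdict(lambda: [0.0, 0])     # bucket -> [score sum, count]
    for n_tokens, score in records:
        bucket = (n_tokens // bucket_size) * bucket_size
        totals[bucket][0] += score
        totals[bucket][1] += 1
    return {b: s / c for b, (s, c) in sorted(totals.items())}

# Illustrative numbers only
records = [(8, 0.44), (12, 0.42), (17, 0.40), (26, 0.30), (31, 0.28)]
report = score_by_length(records)
for bucket, mean_score in report.items():
    print(f"{bucket:>3}-{bucket + 9:<3} tokens: {mean_score:.2f}")
```

A sharp drop between adjacent buckets, like the one described in the symptom above, is the signal to investigate before users report it.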
Production debug guide · Symptom → Root Cause → Fix · 4 entries
Symptom · 01
Loss plateaus after first few epochs, then increases
Fix
Check for exploding gradients: compute gradient norm after each backward pass. If norm > 100, implement gradient clipping.
Symptom · 02
Model works on validation but fails on longer sequences
Fix
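Inspect hidden state norms at inference: if they approach zero after 20 steps, the model is losing early context. Train on longer sequences, lengthen the TBPTT window, and increase hidden size; keep gradient clipping on, but note that clipping guards against exploding gradients and does not restore a vanished signal.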
Symptom · 03
Training loss is low but validation loss is NaN
Fix
Check for numerical overflow in the forget gate. The sigmoid output can underflow. Use log-space computation or add a small epsilon.
Symptom · 04
Output is all same token or constant
Fix
The forget gate may be saturated at 1.0 (remember everything) or 0.0 (forget everything). Check initial bias values: start forget gate bias at 1.0.
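The gradient-norm check from the first entry of this guide can be sketched without any framework. The version below uses NumPy arrays as stand-ins for per-parameter gradients; global_norm and clip_by_global_norm are illustrative helpers written for this example, mirroring the idea behind tf.clip_by_global_norm and torch.nn.utils.clip_grad_norm_ rather than reproducing their exact implementations.

```python
import numpy as np

def global_norm(grads):
    """L2 norm over all gradient arrays, treated as one long vector."""
    return float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))

def clip_by_global_norm(grads, max_norm):
    """Rescale every gradient by one shared factor so the global norm is at most max_norm."""
    norm = global_norm(grads)
    scale = min(1.0, max_norm / (norm + 1e-6))   # small epsilon avoids division by zero
    return [g * scale for g in grads], norm

# Toy gradients whose global norm is sqrt(30^2 + 40^2 + 120^2) = 130
grads = [np.array([30.0, 40.0]), np.array([[120.0]])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=5.0)
norm_after = global_norm(clipped)

print(norm_before, norm_after)
```

Because all arrays share one scale factor, the direction of the update is preserved and only its length changes. Logging norm_before on every step also gives the clipping frequency for free.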
★ RNN / LSTM Quick Debug Cheatsheet
Commands and checks to run when an RNN-based model behaves unexpectedly in training or inference.
Loss not decreasing
Immediate action
Check gradient norm. If norm < 1e-5, vanishing. If norm > 100, exploding.
Commands
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
Check weight initialisation: use orthogonal initialisation for recurrent matrices.
Fix now
Add gradient clipping and reinitialise with orthogonal gain=1.0
Validation BLEU drops with sequence length
Immediate action
Plot per-sequence-length accuracy on a held-out set of varying lengths.
Commands
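for L in [10, 20, 30, 40, 50]: print(L, evaluate(L))  # run inside your eval script; evaluate(L) is your own per-length scoring function
Check hidden state norm at each timestep: output.norm(dim=-1).mean(dim=0)  # output: (batch, seq_len, hidden) from nn.LSTM with batch_first=True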
Fix now
Increase TBPTT length, add gradient clipping, and use peephole connections.
NaN in loss after few iterations
Immediate action
Check for input NaN: torch.isnan(input).any(). If no input NaN, check gradient values.
Commands
torch.autograd.set_detect_anomaly(True)
Inspect forget gate biases: if they saturate at 0 or 1, add small noise to initialisation.
Fix now
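Add gradient clipping and lower the learning rate. If you use dropout, note that nn.LSTM's dropout argument applies between stacked layers, not after the last one; add a separate nn.Dropout for the final output.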
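The orthogonal initialisation recommended in this cheatsheet works because an orthogonal matrix has all singular values equal to 1, so repeated multiplication neither shrinks nor grows a vector. A minimal NumPy sketch of the idea follows; the orthogonal helper is written for this example (in PyTorch the built-in equivalent is torch.nn.init.orthogonal_), and the loop ignores the nonlinearity and inputs of a real RNN.

```python
import numpy as np

def orthogonal(n, gain=1.0, seed=0):
    """Return an n x n orthogonal matrix scaled by gain, built from a QR decomposition."""
    rng = np.random.default_rng(seed)
    q, r = np.linalg.qr(rng.standard_normal((n, n)))
    q = q * np.sign(np.diag(r))      # fix column signs so the result is well defined
    return gain * q

W = orthogonal(64)
h0 = np.ones(64)
h = h0.copy()
for _ in range(50):                  # 50 purely linear recurrent steps
    h = W @ h

norm_ratio = np.linalg.norm(h) / np.linalg.norm(h0)
singular_values = np.linalg.svd(W, compute_uv=False)
print(norm_ratio, singular_values.min(), singular_values.max())
```

With a randomly scaled matrix in place of W, the same loop would drive the norm toward zero or toward overflow within a few dozen steps, which is the failure the cheatsheet entry is guarding against.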
RNN vs LSTM vs GRU Quick Comparison
Property | Vanilla RNN | LSTM | GRU
Gating mechanism | None (single tanh) | Forget, Input, Output gates | Update, Reset gates
Cell state | No | Yes | No (hidden state only)
Parameter count (for hidden=256) | ~1.3M | ~2.0M | ~1.6M
Training speed (relative to LSTM) | 1.5x faster | 1.0x | 1.2x faster
Long sequence performance (>100 steps) | Poor | Excellent | Good
Gradient stability | Very poor | Good (with clipping) | Good
Common production use | Almost never | Time series, NLP, seq2seq | Translation, music generation
PyTorch class | nn.RNN | nn.LSTM | nn.GRU

Key takeaways

1
Vanilla RNNs are not production-ready for sequences longer than ~20 steps: vanishing gradients kill long-range learning.
2
LSTM solves vanishing gradients by using an additive cell state that keeps the gradient from decaying through repeated matrix multiplication.
3
The three LSTM gates (forget, input, output) control what to keep, add, and expose, each with a sigmoid activation producing values in (0,1).
4
Forget gate bias must be initialised to 1.0 to prevent the network from immediately forgetting everything.
5
Gradient clipping is mandatory for any RNN training in production: set max_norm between 1.0 and 5.0.
6
Always mask padding in the loss function: gradients from padding positions silently degrade model quality.
7
GRU is a faster alternative for short sequences; BiLSTM is best for offline full-context tasks.
8
Monitor your model's performance per sequence length after deployment: it reveals gradient health and teacher forcing mismatch.
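The padding-mask takeaway can be made concrete with a few lines of NumPy. This is a minimal sketch assuming the model outputs probabilities and that sequence lengths are known; masked_cross_entropy is an illustrative helper, not a library function.

```python
import numpy as np

def masked_cross_entropy(probs, targets, lengths):
    """Mean negative log-likelihood over non-padded positions only.

    probs:   (batch, seq_len, vocab) predicted probabilities
    targets: (batch, seq_len) integer token ids (padded positions may hold any id)
    lengths: (batch,) number of real tokens per sequence
    """
    seq_len = probs.shape[1]
    mask = np.arange(seq_len)[None, :] < lengths[:, None]            # (batch, seq_len) bool
    picked = np.take_along_axis(probs, targets[:, :, None], axis=2)[:, :, 0]
    nll = -np.log(np.clip(picked, 1e-12, 1.0))                       # clip avoids log(0)
    return float((nll * mask).sum() / max(mask.sum(), 1))

probs = np.full((2, 3, 4), 0.25)           # uniform predictions over a 4-token vocabulary
targets = np.zeros((2, 3), dtype=int)
lengths = np.array([3, 1])                 # 4 real tokens, 2 padded positions

loss = masked_cross_entropy(probs, targets, lengths)
print(loss)   # equals log(4) for uniform predictions
```

The denominator is the number of real tokens, not batch size times sequence length, so heavily padded batches do not shrink the loss and the gradient along with it.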

Common mistakes to avoid

5 patterns
×

Using vanilla RNN for sequences longer than 20 steps

Symptom
Training loss decreases but validation loss increases — the model memorises short patterns and ignores long ones.
Fix
Switch to LSTM or GRU. If you must use RNN, apply gradient clipping and use orthogonal initialisation.
×

Not masking padding in the loss function

Symptom
Model converges to predicting padding token as the most frequent output.
Fix
Implement a masked loss function as shown in the code above. Always log the effective batch size (sum of lengths) to verify.
×

Forgetting to truncate backpropagation for very long sequences

Symptom
Out-of-memory (OOM) error on sequences longer than a few hundred steps, even with small batch size.
Fix
Use truncated BPTT: split the sequence into chunks, detach the hidden state between chunks. PyTorch's pack_padded_sequence does not truncate by itself.
×

Initialising forget gate bias to 0

Symptom
Cell state decays to zero quickly, model forgets everything from the first few steps.
Fix
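Initialise forget gate bias to 1.0. PyTorch's nn.LSTM has no constructor argument for this: write the value into the forget-gate slice of the bias_ih_l{k} and bias_hh_l{k} parameters, which are laid out in input, forget, cell, output order, so the slice is [hidden_size:2*hidden_size]. The two biases are summed, so set one to 1.0 and the other to 0.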
×

Not using gradient clipping even when gradients explode

Symptom
Loss spikes to NaN after a few iterations.
Fix
Add torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) after loss.backward() and before optimizer.step(). Monitor clipping frequency.
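For the forget-gate bias mistake listed above, nn.LSTM exposes its biases as the parameters bias_ih_l{k} and bias_hh_l{k}, each laid out as input, forget, cell, output blocks. One way to set the effective forget-gate bias is to write into those slices directly. The sketch below assumes PyTorch is installed; init_forget_gate_bias is a helper written for this example, not a PyTorch function.

```python
import torch
import torch.nn as nn

def init_forget_gate_bias(lstm: nn.LSTM, value: float = 1.0) -> None:
    """Set the effective forget-gate bias of every layer and direction to `value`.

    bias_ih and bias_hh are added together inside the LSTM, so the value goes
    into one of them and the matching slice of the other is zeroed.
    """
    h = lstm.hidden_size
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            if name.startswith("bias_ih"):
                param[h:2 * h].fill_(value)   # forget-gate block of the input bias
            elif name.startswith("bias_hh"):
                param[h:2 * h].zero_()        # keep the sum equal to `value`

lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
init_forget_gate_bias(lstm, 1.0)

effective = lstm.bias_ih_l0 + lstm.bias_hh_l0
print(effective[64:128].min().item(), effective[64:128].max().item())
```

Setting both bias vectors to 1.0 would give an effective bias of 2.0, which is a common slip when adapting snippets written for frameworks that keep a single bias per gate.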
INTERVIEW PREP · PRACTICE MODE

Interview Questions on This Topic

Q01 · SENIOR
Explain the vanishing gradient problem in RNNs and how LSTMs solve it.
Q02 · SENIOR
When would you choose GRU over LSTM for a production system?
Q03 · SENIOR
What is teacher forcing and why is it a problem in production?
Q01 of 03 · SENIOR

Explain the vanishing gradient problem in RNNs and how LSTMs solve it.

ANSWER
In a vanilla RNN, the gradient of the loss at timestep T with respect to the hidden state at an early timestep is a product of the hidden-state Jacobians across all intermediate timesteps. Since the same weight matrix is shared at each step, that product contains roughly T-1 factors of W_hh (each scaled by a tanh derivative that is at most 1). If the singular values of W_hh are less than 1, the product shrinks exponentially: vanishing. If they are greater than 1, it can grow exponentially: exploding. LSTMs address vanishing by introducing a cell state that is updated additively rather than through a repeated matrix multiplication. Along the direct cell-state path, the derivative of the cell state at step t with respect to the cell state at step t-1 is simply the forget gate value (one learned value between 0 and 1 per element). There is no repeated matrix multiplication on that path. The cell state acts as a gradient highway: gradient can flow backward nearly unchanged unless the forget gate explicitly closes.
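The contrast in this answer can be written out. The sketch below uses standard textbook notation and assumes h_t = tanh(W_hh h_{t-1} + W_xh x_t + b) for the vanilla RNN and c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t for the LSTM cell state; the symbols W_xh, b, i_t and g_t are not named in the answer and are introduced here only for the derivation.

```latex
% Vanilla RNN: backpropagating from step T to step 1 multiplies T-1 Jacobians,
% each containing the same recurrent matrix W_{hh}
\frac{\partial h_T}{\partial h_1}
  = \prod_{t=2}^{T} \frac{\partial h_t}{\partial h_{t-1}}
  = \prod_{t=2}^{T} \operatorname{diag}\!\left(1 - h_t^{2}\right) W_{hh}

% LSTM cell state: the update is additive, so the direct path is elementwise
c_t = f_t \odot c_{t-1} + i_t \odot g_t
\qquad\Longrightarrow\qquad
\left.\frac{\partial c_T}{\partial c_1}\right|_{\text{direct path}}
  = \prod_{t=2}^{T} \operatorname{diag}(f_t)
```

The first product shrinks or grows with the singular values of W_hh, whatever the data says. The second contains no weight matrix at all and stays close to the identity as long as the forget gates stay close to 1, which is why the forget gate bias is initialised high. The direct path ignores the indirect dependence of the gates on earlier hidden states, which adds further terms but does not remove this one.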
FAQ · 5 QUESTIONS

Frequently Asked Questions

01
What is the difference between an RNN and an LSTM?
02
When should I use a GRU instead of an LSTM?
03
How do I prevent my LSTM from forgetting early parts of the sequence?
04
Why does my LSTM model perform well on validation but poorly on longer sequences in production?
05
What is the best way to handle variable-length sequences in a batch?