RNNs & LSTMs — Sequential Data Processing
Some data has order: words in a sentence, steps in a time series, frames in a video. Recurrent Neural Networks process sequences by maintaining a hidden state — a memory that carries information from past steps into the future.
The Vanilla RNN
At each time step t, the RNN takes the current input xt and the previous hidden state ht-1, and produces a new hidden state ht:
ht = tanh(Whh · ht-1 + Wxh · xt + b)
yt = Why · ht + by
The same weights (Whh, Wxh) are reused at every time step — weight sharing over time.
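As a minimal sketch of this recurrence (the dimensions below are made up for illustration), note that the same weight tensors are applied at every step of the loop:

```python
import torch

# Toy vanilla RNN forward pass. W_xh, W_hh, W_hy and the biases are the
# SAME tensors at every time step — that is the weight sharing over time.
torch.manual_seed(0)
input_dim, hidden_dim, output_dim = 8, 16, 4
W_xh = torch.randn(input_dim, hidden_dim) * 0.1
W_hh = torch.randn(hidden_dim, hidden_dim) * 0.1
W_hy = torch.randn(hidden_dim, output_dim) * 0.1
b_h = torch.zeros(hidden_dim)
b_y = torch.zeros(output_dim)

xs = torch.randn(10, input_dim)    # a sequence of 10 steps
h = torch.zeros(hidden_dim)        # initial hidden state
outputs = []
for x_t in xs:                     # same weights reused at each step t
    h = torch.tanh(h @ W_hh + x_t @ W_xh + b_h)
    outputs.append(h @ W_hy + b_y)
```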
The Vanishing Gradient Problem
When backpropagating through many time steps, the gradient is multiplied at every step by (roughly) the recurrent weight matrix and the tanh derivative. If that repeated factor is smaller than 1 (e.g., the largest singular value of Whh is below 1), gradients shrink exponentially → the model can't learn long-range dependencies. If it is larger than 1, gradients explode. This is why vanilla RNNs struggle with sequences longer than ~10 steps.
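A toy scalar stand-in for this argument (a single number playing the role of the recurrent weight matrix): apply it 100 times and look at the gradient that reaches step 0.

```python
import torch

# The gradient at step 0 is w**100, so it either vanishes or explodes
# depending on whether w is below or above 1.
for w in (0.9, 1.1):
    h0 = torch.ones(1, requires_grad=True)
    h = h0
    for _ in range(100):
        h = w * h                  # same "weight" reused at every step
    h.backward()
    print(f"w={w}: gradient at step 0 = {h0.grad.item():.3e}")
# w=0.9 → ~2.7e-05 (vanished); w=1.1 → ~1.4e+04 (exploded)
```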
LSTM — Long Short-Term Memory
LSTMs solve the vanishing gradient problem by introducing a cell state (Ct) — a separate memory lane that flows through the sequence with only additive updates. Three gates control what information enters, leaves, and is remembered.
The Three LSTM Gates Explained
🗑️ Forget Gate
ft = σ(Wf·[ht-1, xt] + bf)
Sigmoid output (0–1), multiplied element-wise with Ct-1: 0 = forget completely, 1 = keep everything. E.g., when the text switches to a new topic, the forget gate can reset the accumulated context.
✏️ Input Gate
it = σ(Wi·[ht-1, xt] + bi)
Decides how much new information to write to the cell state, combined with a tanh candidate: C̃t = tanh(Wc·[ht-1, xt] + bc)
💾 Cell Update
Ct = ft ⊙ Ct-1 + it ⊙ C̃t
The cell state update is simply (forget old) + (add new). This additive update lets gradients flow back almost unchanged, which is how LSTMs address the vanishing gradient problem.
📤 Output Gate
ot = σ(Wo·[ht-1, xt] + bo)
ht = ot ⊙ tanh(Ct)
Decides which part of the cell state to expose as the hidden state (the output), filtering the full cell memory down to what is relevant at this step.
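To make the four equations concrete, here is a single LSTM step written out by hand (the weight names, shapes, and the concatenated [ht-1, xt] layout are assumptions of this sketch, not a particular library's internals):

```python
import torch

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    # Each W_* maps (hidden + input) -> hidden; each b_* has shape (hidden,)
    z = torch.cat([h_prev, x_t], dim=-1)       # [h_{t-1}, x_t]
    f_t = torch.sigmoid(z @ W_f + b_f)         # forget gate
    i_t = torch.sigmoid(z @ W_i + b_i)         # input gate
    c_tilde = torch.tanh(z @ W_c + b_c)        # candidate memory
    c_t = f_t * c_prev + i_t * c_tilde         # additive cell update
    o_t = torch.sigmoid(z @ W_o + b_o)         # output gate
    h_t = o_t * torch.tanh(c_t)                # expose a filtered view of c_t
    return h_t, c_t
```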
GRU — Gated Recurrent Unit
The GRU (2014) simplifies the LSTM by merging the cell state and hidden state into one, and combining forget + input gates into a single update gate. Fewer parameters, trains faster, often matches LSTM quality.
| Property | RNN | LSTM | GRU |
|---|---|---|---|
| Parameters (hidden=256) | ~65K | ~265K | ~200K |
| Long-range memory | Poor | Excellent | Good |
| Training speed | Fast | Slow | Medium |
| Best for | Short sequences | Complex patterns | Most practical use cases |
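To make the gate merging described above concrete, here is a single GRU step written out by hand (the names and the concatenated input layout are assumptions of this sketch): one state h, and one update gate zt doing the job of the LSTM's forget and input gates.

```python
import torch

def gru_step(x_t, h_prev, W_z, W_r, W_h, b_z, b_r, b_h):
    # Each W_* maps (hidden + input) -> hidden; each b_* has shape (hidden,)
    zx = torch.cat([h_prev, x_t], dim=-1)
    z_t = torch.sigmoid(zx @ W_z + b_z)                  # update gate
    r_t = torch.sigmoid(zx @ W_r + b_r)                  # reset gate
    h_tilde = torch.tanh(torch.cat([r_t * h_prev, x_t], dim=-1) @ W_h + b_h)
    return (1 - z_t) * h_prev + z_t * h_tilde            # keep old vs take new
```

In PyTorch, nn.GRU is otherwise a near drop-in replacement for nn.LSTM; it returns a single hidden state instead of a (hidden, cell) pair.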
PyTorch LSTM — Text Sentiment Classifier
```python
import torch
import torch.nn as nn

class SentimentLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True  # Process sequence left→right AND right→left
        )
        # Bidirectional: 2× hidden_dim
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        # x: (batch, seq_len) token ids
        embedded = self.dropout(self.embedding(x))
        # Pack padded sequence for efficiency
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, (hidden, cell) = self.lstm(packed)
        # Use final hidden states from both directions
        # hidden shape: (num_layers * 2, batch, hidden_dim)
        fwd = hidden[-2]  # Forward LSTM final state
        bwd = hidden[-1]  # Backward LSTM final state
        combined = torch.cat([fwd, bwd], dim=1)  # (batch, hidden_dim * 2)
        out = self.fc(self.dropout(combined))
        return out.squeeze(1)  # (batch,) — logit

model = SentimentLSTM(vocab_size=30000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy for sentiment
```
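A quick smoke test with made-up data (the batch size, sequence lengths, and labels below are illustrative) shows one full training step with these pieces:

```python
# Illustrative dummy batch: 4 sequences padded to length 20.
x = torch.randint(1, 30000, (4, 20))       # token ids (0 reserved for padding)
lengths = torch.tensor([20, 17, 12, 5])    # true lengths before padding
labels = torch.tensor([1.0, 0.0, 1.0, 1.0])

logits = model(x, lengths)                 # (batch,)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```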
Seq2Seq — Encoder-Decoder Architecture
Seq2Seq maps a variable-length input sequence to a variable-length output sequence. The encoder LSTM compresses the input into a context vector; the decoder LSTM generates the output one token at a time. Used in: machine translation, summarisation, speech-to-text (before Transformers took over).
```python
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, src):
        embedded = self.embedding(src)
        _, (hidden, cell) = self.lstm(embedded)
        return hidden, cell  # Context vector

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, token, hidden, cell):
        # token: (batch, 1) — one token at a time
        embedded = self.embedding(token)
        out, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc(out.squeeze(1))  # (batch, vocab)
        return prediction, hidden, cell

# Decoding loop (teacher forcing during training)
def decode(decoder, tgt, hidden, cell):
    outputs = []
    for t in range(tgt.shape[1]):
        pred, hidden, cell = decoder(tgt[:, t:t+1], hidden, cell)
        outputs.append(pred)
    return torch.stack(outputs, dim=1)
```
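A usage sketch under assumed vocabulary sizes and shapes, showing teacher forcing with the usual one-token shift between decoder inputs and targets:

```python
SRC_VOCAB, TGT_VOCAB = 8000, 8000            # illustrative vocabulary sizes
encoder = Encoder(SRC_VOCAB, embed_dim=128, hidden_dim=256)
decoder = Decoder(TGT_VOCAB, embed_dim=128, hidden_dim=256)

src = torch.randint(1, SRC_VOCAB, (4, 15))   # source batch
tgt = torch.randint(1, TGT_VOCAB, (4, 12))   # target batch (assumed to include start/end tokens)

hidden, cell = encoder(src)                           # compress source into context
logits = decode(decoder, tgt[:, :-1], hidden, cell)   # feed tokens 0..n-1
loss = nn.CrossEntropyLoss()(
    logits.reshape(-1, TGT_VOCAB), tgt[:, 1:].reshape(-1)  # predict tokens 1..n
)
```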
When to Use RNNs vs Transformers
Still Use RNNs For
- Real-time streaming data (RNNs process one step at a time)
- Very long sequences (>10K steps) where Transformer attention is too expensive
- Resource-constrained edge devices
- Time series with explicit temporal dependencies (e.g., sensor fusion)
Use Transformers Instead For
- NLP tasks (text, translation, summarisation) — Transformers dominate
- When you can parallelize training across GPUs
- Tasks where global context matters (both directions)
- When you have large amounts of pre-training data
Frequently Asked Questions
Are LSTMs obsolete now that we have Transformers?
Not entirely. LSTMs are still excellent for streaming time-series data where you process one step at a time without needing the full sequence (e.g., anomaly detection in sensor streams, real-time speech recognition on-device). For most NLP tasks though, Transformers have largely replaced LSTMs. New architectures like Mamba (State Space Models) are trying to combine the best of both worlds.
What's the difference between hidden state and cell state in LSTM?
The hidden state h_t is the "short-term" memory — the LSTM's output at each step, passed to the next step and to any downstream layer. The cell state C_t is the "long-term" memory — a separate state that only the LSTM gates can modify, never directly exposed to other layers. This separation is why LSTMs remember long-range dependencies better than RNNs.
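A small illustration of the split (dimensions are made up): nn.LSTM returns both states, but only the hidden state doubles as the per-step output.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(2, 5, 8)               # (batch, seq_len, features)
output, (h_n, c_n) = lstm(x)
print(output.shape)  # (2, 5, 16) — hidden state at every step
print(h_n.shape)     # (1, 2, 16) — final hidden state
print(c_n.shape)     # (1, 2, 16) — final cell state, kept internal to the gates
```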
How do I handle variable-length sequences in PyTorch?
Use padding + masking: pad all sequences to the same length with a padding token (0), then use pack_padded_sequence before the LSTM and pad_packed_sequence after. This tells PyTorch to skip computation on padding tokens, saving time and preventing the hidden state from being corrupted by padding. Always set padding_idx in your Embedding layer too.
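A self-contained sketch of that pad → pack → LSTM → unpack pattern (the tiny vocabulary and dimensions are placeholders):

```python
import torch
import torch.nn as nn

seqs = [torch.tensor([5, 9, 2]), torch.tensor([7, 1])]      # ragged batch
padded = nn.utils.rnn.pad_sequence(seqs, batch_first=True)  # pads with 0
lengths = torch.tensor([3, 2])

emb = nn.Embedding(10, 8, padding_idx=0)   # padding_idx keeps pad embedding at zero
lstm = nn.LSTM(8, 16, batch_first=True)

packed = nn.utils.rnn.pack_padded_sequence(
    emb(padded), lengths, batch_first=True, enforce_sorted=False
)
packed_out, _ = lstm(packed)               # padding steps are skipped
out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)                           # (2, 3, 16)
```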