RNNs & LSTMs — Sequential Data Processing

Some data has order: words in a sentence, steps in a time series, frames in a video. Recurrent Neural Networks process sequences by maintaining a hidden state — a memory that carries information from past steps into the future.

🔄 Covers: Vanilla RNN · Vanishing Gradient · LSTM Gates · GRU · Seq2Seq · Time Series · Text · Audio · PyTorch

The Vanilla RNN

At each time step t, the RNN takes the current input x_t and the previous hidden state h_{t-1}, and produces a new hidden state h_t:

h_t = tanh(W_hh · h_{t-1} + W_xh · x_t + b_h)

y_t = W_hy · h_t + b_y

The same weights (W_hh, W_xh) are reused at every time step — weight sharing over time.
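The two equations above can be sketched in a few lines of NumPy. Everything here is illustrative: toy sizes, randomly initialised weights.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    # h_t = tanh(W_hh · h_{t-1} + W_xh · x_t + b_h)
    return np.tanh(W_hh @ h_prev + W_xh @ x_t + b_h)

rng = np.random.default_rng(0)
input_dim, hidden_dim = 4, 3
W_xh = 0.1 * rng.normal(size=(hidden_dim, input_dim))
W_hh = 0.1 * rng.normal(size=(hidden_dim, hidden_dim))
b_h = np.zeros(hidden_dim)

h = np.zeros(hidden_dim)                    # initial hidden state h_0
sequence = rng.normal(size=(5, input_dim))  # 5 time steps of input
for x_t in sequence:                        # the SAME weights at every step
    h = rnn_step(x_t, h, W_xh, W_hh, b_h)
print(h.shape)  # (3,)
```

Note that the loop carries only h forward: all the network "remembers" about steps 1–4 must fit in that one vector.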

The Vanishing Gradient Problem

When backpropagating through many time steps, the gradient is multiplied by the recurrent weight matrix (and the activation's derivative) once per step. If the weights' magnitudes are below 1, gradients shrink exponentially and the model can't learn long-range dependencies; if they are above 1, gradients explode. This is why vanilla RNNs struggle with sequences longer than ~10 steps.
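A toy calculation makes this concrete, collapsing the repeated matrix product into a single scalar factor w applied once per step:

```python
# Each backprop step through time multiplies the gradient by (roughly)
# the recurrent weight. Below 1 it vanishes; above 1 it explodes.
for w in (0.9, 1.1):
    scales = {k: w ** k for k in (1, 5, 10, 20)}
    print(w, scales)
# w=0.9: by step 20 the gradient scale is ~0.12  (vanishing)
# w=1.1: by step 20 the gradient scale is ~6.73  (exploding)
```

Twenty steps is enough to shrink the signal by almost an order of magnitude, which is why learning dependencies across long spans is so hard for the vanilla RNN.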

[Chart: gradient magnitude decaying toward zero between step 1 and step 20 (the vanishing gradient)]

LSTM — Long Short-Term Memory

LSTMs solve the vanishing gradient problem by introducing a cell state (Ct) — a separate memory lane that flows through the sequence with only additive updates. Three gates control what information enters, leaves, and is remembered.


The Three LSTM Gates Explained

🗑️ Forget Gate

f_t = σ(W_f·[h_{t-1}, x_t] + b_f)

Sigmoid output (0–1). Multiplied element-wise with C_{t-1}. 0 = forget completely, 1 = keep everything. E.g., when reading a new document topic, reset accumulated context.

✏️ Input Gate

i_t = σ(W_i·[h_{t-1}, x_t] + b_i)

Decides how much new information to write to the cell state. Combined with a tanh candidate: C̃_t = tanh(W_c·[h_{t-1}, x_t] + b_c)

💾 Cell Update

C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t

The cell state update is simply: (forget old) + (add new). Because the update is additive rather than repeatedly multiplicative, gradients can flow back through many steps almost unchanged, which is what solves the vanishing gradient problem.

📤 Output Gate

o_t = σ(W_o·[h_{t-1}, x_t] + b_o)

h_t = o_t ⊙ tanh(C_t)

Decides what part of the cell state to expose as the hidden state (output). Filters the full cell memory into the relevant output.
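Putting the four equations together, one full LSTM step can be written out by hand. This NumPy sketch uses toy sizes and packs the four gates' weights into a single stacked matrix W, a common implementation layout (the layout and sizes are illustrative, not a specific library's):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    # W maps the concatenated [h_{t-1}, x_t] to the stacked
    # pre-activations of the four gates: f, i, candidate, o.
    n = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x_t]) + b
    f = sigmoid(z[0*n:1*n])          # forget gate
    i = sigmoid(z[1*n:2*n])          # input gate
    c_tilde = np.tanh(z[2*n:3*n])    # candidate values
    o = sigmoid(z[3*n:4*n])          # output gate
    c_t = f * c_prev + i * c_tilde   # additive cell update
    h_t = o * np.tanh(c_t)           # filtered output
    return h_t, c_t

rng = np.random.default_rng(1)
hidden, inp = 3, 2
W = 0.1 * rng.normal(size=(4 * hidden, hidden + inp))
b = np.zeros(4 * hidden)
h, c = np.zeros(hidden), np.zeros(hidden)
h, c = lstm_step(rng.normal(size=inp), h, c, W, b)
print(h.shape, c.shape)  # (3,) (3,)
```

Notice that c_t is touched only by an element-wise multiply and an add: there is no weight matrix between C_{t-1} and C_t, which is exactly the gradient highway described above.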

GRU — Gated Recurrent Unit

The GRU (2014) simplifies the LSTM by merging the cell state and hidden state into one, and combining forget + input gates into a single update gate. Fewer parameters, trains faster, often matches LSTM quality.

Property                                    RNN               LSTM               GRU
Recurrent (hidden→hidden) params (h=256)    ~65K              ~265K              ~200K
Long-range memory                           Poor              Excellent          Good
Training speed                              Fast              Slow               Medium
Best for                                    Short sequences   Complex patterns   Most practical use cases
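In PyTorch, switching from an LSTM to a GRU is nearly a drop-in change. The one API difference to remember: nn.GRU has no cell state, so it returns (output, hidden) rather than (output, (hidden, cell)). The sizes below are arbitrary, for illustration only:

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=128, hidden_size=256, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 128)  # (batch, seq_len, input_dim)
out, hidden = gru(x)         # no cell state, unlike nn.LSTM
print(out.shape)     # torch.Size([4, 10, 256])
print(hidden.shape)  # torch.Size([2, 4, 256]) = (num_layers, batch, hidden)
```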

PyTorch LSTM — Text Sentiment Classifier

Python · LSTM Sentiment Analysis with PyTorch
import torch
import torch.nn as nn

class SentimentLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True   # Process sequence left→right AND right→left
        )
        # Bidirectional: 2× hidden_dim
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        # x: (batch, seq_len) token ids
        embedded = self.dropout(self.embedding(x))

        # Pack padded sequence for efficiency
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, (hidden, cell) = self.lstm(packed)

        # Use final hidden states from both directions
        # hidden shape: (num_layers * 2, batch, hidden_dim)
        fwd = hidden[-2]   # Forward LSTM final state
        bwd = hidden[-1]   # Backward LSTM final state
        combined = torch.cat([fwd, bwd], dim=1)  # (batch, hidden_dim * 2)

        out = self.fc(self.dropout(combined))
        return out.squeeze(1)   # (batch,) — logit

model = SentimentLSTM(vocab_size=30000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy for sentiment
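A single training step for a model like this might look as follows. To keep the sketch self-contained it uses a tiny stand-in classifier with the same forward structure (embed → pack → LSTM → final hidden → linear); in practice you would train the SentimentLSTM defined above. All data here is random dummy input:

```python
import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    def __init__(self, vocab_size=30000, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, lengths):
        packed = nn.utils.rnn.pack_padded_sequence(
            self.embedding(x), lengths, batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)
        return self.fc(hidden[-1]).squeeze(1)  # (batch,) logits

model = TinyClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

x = torch.randint(1, 30000, (4, 12))      # dummy token ids (batch, seq_len)
lengths = torch.tensor([12, 9, 7, 5])     # true lengths before padding
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

optimizer.zero_grad()
loss = criterion(model(x, lengths), labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # tame exploding grads
optimizer.step()
print(loss.item())
```

Gradient clipping (the clip_grad_norm_ line) is standard practice for recurrent models, since exploding gradients remain a risk even with gating.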

Seq2Seq — Encoder-Decoder Architecture

Seq2Seq maps a variable-length input sequence to a variable-length output sequence. The encoder LSTM compresses the input into a context vector; the decoder LSTM generates the output one token at a time. Used in: machine translation, summarisation, speech-to-text (before Transformers took over).

Python · Minimal Seq2Seq Encoder-Decoder
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, src):
        embedded = self.embedding(src)
        _, (hidden, cell) = self.lstm(embedded)
        return hidden, cell   # Context vector

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, token, hidden, cell):
        # token: (batch, 1) — one token at a time
        embedded = self.embedding(token)
        out, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc(out.squeeze(1))   # (batch, vocab)
        return prediction, hidden, cell

# Decoding loop (teacher forcing during training)
def decode(decoder, tgt, hidden, cell):
    outputs = []
    for t in range(tgt.shape[1]):
        pred, hidden, cell = decoder(tgt[:, t:t+1], hidden, cell)
        outputs.append(pred)
    return torch.stack(outputs, dim=1)
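At inference time there is no target sequence to teacher-force, so the decoder's own argmax prediction is fed back as the next input (greedy decoding). The sketch below repeats the Encoder/Decoder from above so it runs standalone; SOS_ID is a hypothetical start-of-sequence token id:

```python
import torch
import torch.nn as nn

SOS_ID = 1  # hypothetical start-of-sequence token id

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
    def forward(self, src):
        _, (hidden, cell) = self.lstm(self.embedding(src))
        return hidden, cell  # context vector

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
    def forward(self, token, hidden, cell):
        out, (hidden, cell) = self.lstm(self.embedding(token), (hidden, cell))
        return self.fc(out.squeeze(1)), hidden, cell

@torch.no_grad()
def greedy_decode(encoder, decoder, src, max_len=10):
    hidden, cell = encoder(src)                  # compress the source
    token = torch.full((src.shape[0], 1), SOS_ID, dtype=torch.long)
    outputs = []
    for _ in range(max_len):
        pred, hidden, cell = decoder(token, hidden, cell)
        token = pred.argmax(dim=1, keepdim=True) # feed prediction back in
        outputs.append(token)
    return torch.cat(outputs, dim=1)             # (batch, max_len)

enc, dec = Encoder(100, 16, 32), Decoder(100, 16, 32)
src = torch.randint(0, 100, (2, 7))
out_ids = greedy_decode(enc, dec, src)
print(out_ids.shape)  # torch.Size([2, 10])
```

In a real system you would also stop at an end-of-sequence token and often use beam search instead of pure argmax, but the feed-back-your-own-prediction loop is the core difference from teacher forcing.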

When to Use RNNs vs Transformers

Still Use RNNs For

  • Real-time streaming data (RNNs process one step at a time)
  • Very long sequences (>10K steps) where Transformer attention is too expensive
  • Resource-constrained edge devices
  • Time series with explicit temporal dependencies (e.g., sensor fusion)

Use Transformers Instead For

  • NLP tasks (text, translation, summarisation) — Transformers dominate
  • When you can parallelize training across GPUs
  • Tasks where global context matters (both directions)
  • When you have large amounts of pre-training data

Frequently Asked Questions

Are LSTMs obsolete now that we have Transformers?

Not entirely. LSTMs are still excellent for streaming time-series data where you process one step at a time without needing the full sequence (e.g., anomaly detection in sensor streams, real-time speech recognition on-device). For most NLP tasks though, Transformers have largely replaced LSTMs. New architectures like Mamba (State Space Models) are trying to combine the best of both worlds.

What's the difference between hidden state and cell state in LSTM?

The hidden state h_t is the "short-term" memory — the LSTM's output at each step, passed to the next step and to any downstream layer. The cell state C_t is the "long-term" memory — a separate state that only the LSTM gates can modify, never directly exposed to other layers. This separation is why LSTMs remember long-range dependencies better than RNNs.

How do I handle variable-length sequences in PyTorch?

Use padding + masking: pad all sequences to the same length with a padding token (0), then use pack_padded_sequence before the LSTM and pad_packed_sequence after. This tells PyTorch to skip computation on padding tokens, saving time and preventing the hidden state from being corrupted by padding. Always set padding_idx in your Embedding layer too.
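The full pad → pack → LSTM → unpack pipeline looks like this (the token ids are made up for illustration):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (
    pad_sequence, pack_padded_sequence, pad_packed_sequence
)

# Three variable-length sequences of token ids.
seqs = [torch.tensor([5, 3, 8, 2]), torch.tensor([7, 1]), torch.tensor([4, 9, 6])]
lengths = torch.tensor([len(s) for s in seqs])        # [4, 2, 3]

padded = pad_sequence(seqs, batch_first=True, padding_value=0)  # (3, 4)
emb = nn.Embedding(10, 8, padding_idx=0)   # padding_idx keeps pad vectors at 0
lstm = nn.LSTM(8, 16, batch_first=True)

packed = pack_padded_sequence(
    emb(padded), lengths, batch_first=True, enforce_sorted=False
)
packed_out, _ = lstm(packed)               # no compute on padding positions
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)    # torch.Size([3, 4, 16])
print(out_lengths)  # tensor([4, 2, 3])
```

After unpacking, the output rows beyond each sequence's true length are zero-filled, so downstream layers can mask them out reliably.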
