Deep Learning (Intermediate)

Understand RNNs, LSTMs, and GRUs - neural networks designed for sequential data that maintain memory across time steps.

Tags: rnn, lstm, gru, sequence-modeling, neural-networks

Recurrent Neural Networks (RNNs)

Recurrent Neural Networks are designed for sequential data. Unlike feedforward networks, RNNs have connections that loop back, allowing them to maintain a "memory" of previous inputs.

The Basic Idea

Process sequences one element at a time, maintaining hidden state:

Input:  x₁ → x₂ → x₃ → x₄
         ↓    ↓    ↓    ↓
Hidden: h₁ → h₂ → h₃ → h₄
                        ↓
Output:                 y

The hidden state carries information from previous time steps.

Vanilla RNN

Forward Pass

hₜ = tanh(Wₓₕxₜ + Wₕₕhₜ₋₁ + b)
yₜ = Wₕᵧhₜ + bᵧ

Where:

  • xₜ: input at time t
  • hₜ: hidden state at time t
  • yₜ: output at time t
  • W: weight matrices (shared across time!)

Weight Sharing

Same weights used at every time step:

  • Handles variable-length sequences
  • Learns patterns regardless of position
  • Fewer parameters than unrolled feedforward
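To make weight sharing concrete, here is a minimal sketch of the vanilla RNN recurrence in PyTorch. The sizes and the 0.1 initialization scale are arbitrary choices for illustration; the point is that the same `W_xh`, `W_hh`, and `W_hy` are reused at every time step.

```python
import torch

torch.manual_seed(0)
input_dim, hidden_dim, output_dim = 3, 4, 2

# One set of weights, shared across all time steps
W_xh = torch.randn(hidden_dim, input_dim) * 0.1   # input-to-hidden
W_hh = torch.randn(hidden_dim, hidden_dim) * 0.1  # hidden-to-hidden
b_h  = torch.zeros(hidden_dim)
W_hy = torch.randn(output_dim, hidden_dim) * 0.1  # hidden-to-output
b_y  = torch.zeros(output_dim)

def rnn_step(x_t, h_prev):
    """One step: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b)."""
    h_t = torch.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)
    y_t = W_hy @ h_t + b_y
    return h_t, y_t

# Unroll over a length-4 sequence, reusing the SAME weights each step
h = torch.zeros(hidden_dim)
for x in torch.randn(4, input_dim):
    h, y = rnn_step(x, h)

print(h.shape, y.shape)  # hidden stays size 4, output size 2
```

Because the loop body never changes, the same code handles a sequence of any length, which is exactly the variable-length property noted above.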

The Vanishing Gradient Problem

During backpropagation through time (BPTT):

∂L/∂h₁ = ∂L/∂h₄ × ∂h₄/∂h₃ × ∂h₃/∂h₂ × ∂h₂/∂h₁
                    ↑          ↑          ↑
              Each term can be < 1

Gradients multiply at each step:

  • If < 1: Gradients vanish (can't learn long-term dependencies)
  • If > 1: Gradients explode (training unstable)

Vanilla RNNs struggle with sequences longer than ~10-20 steps.
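The shrinking product of Jacobians can be observed directly with autograd. This is a toy demonstration (inputs omitted, small random recurrence weights chosen so the relevant factors are below 1); the exact numbers depend on the seed, but the gradient reaching the first hidden state is tiny.

```python
import torch

torch.manual_seed(0)
hidden_dim, steps = 8, 30
W_hh = torch.randn(hidden_dim, hidden_dim) * 0.1  # small recurrence weights

# Track the gradient flowing back to the initial hidden state
h0 = torch.zeros(hidden_dim, requires_grad=True)
state = h0
for _ in range(steps):
    state = torch.tanh(W_hh @ state)  # recurrence only, no input

state.sum().backward()
print(h0.grad.norm().item())  # a very small number: the gradient has vanished
```

With weights scaled the other way (e.g. multiplying by 2.0 instead of 0.1), the same loop exhibits exploding gradients instead.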

LSTM (Long Short-Term Memory)

Designed to solve vanishing gradients with a gating mechanism:

┌─────────────────────────────────────┐
│  Cell State (cₜ) ──────────────────→ │
│       ↑         ↑                    │
│   [forget]   [input]                 │
│       ↑         ↑                    │
│  fₜ = σ(...)  iₜ = σ(...)           │
│                 ↓                    │
│              c̃ₜ = tanh(...)         │
│                                      │
│  Output: oₜ = σ(...), hₜ = oₜ×tanh(cₜ)│
└─────────────────────────────────────┘

The Three Gates

Forget Gate (fₜ): What to forget from cell state

fₜ = σ(Wf[hₜ₋₁, xₜ] + bf)

Input Gate (iₜ): What new information to add

iₜ = σ(Wi[hₜ₋₁, xₜ] + bi)
c̃ₜ = tanh(Wc[hₜ₋₁, xₜ] + bc)

Output Gate (oₜ): What to output

oₜ = σ(Wo[hₜ₋₁, xₜ] + bo)
hₜ = oₜ × tanh(cₜ)

Cell State Update

cₜ = fₜ × cₜ₋₁ + iₜ × c̃ₜ

When the forget gate stays near 1, the cell state provides a "highway" along which gradients flow nearly unchanged across many time steps.
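The gate equations above translate directly into code. Below is a minimal sketch of one LSTM step, using the concatenated `[h_{t-1}, x_t]` formulation; sizes and initialization are arbitrary for illustration.

```python
import torch

torch.manual_seed(0)
input_dim, hidden_dim = 3, 4
n = hidden_dim + input_dim  # every gate sees [h_{t-1}, x_t]

# One weight matrix and bias per gate (toy sizes)
W_f, b_f = torch.randn(hidden_dim, n) * 0.1, torch.zeros(hidden_dim)
W_i, b_i = torch.randn(hidden_dim, n) * 0.1, torch.zeros(hidden_dim)
W_c, b_c = torch.randn(hidden_dim, n) * 0.1, torch.zeros(hidden_dim)
W_o, b_o = torch.randn(hidden_dim, n) * 0.1, torch.zeros(hidden_dim)

def lstm_step(x_t, h_prev, c_prev):
    z = torch.cat([h_prev, x_t])          # [h_{t-1}, x_t]
    f = torch.sigmoid(W_f @ z + b_f)      # forget gate
    i = torch.sigmoid(W_i @ z + b_i)      # input gate
    c_tilde = torch.tanh(W_c @ z + b_c)   # candidate cell state
    c = f * c_prev + i * c_tilde          # cell state update
    o = torch.sigmoid(W_o @ z + b_o)      # output gate
    h = o * torch.tanh(c)                 # new hidden state
    return h, c

h = c = torch.zeros(hidden_dim)
for x in torch.randn(5, input_dim):
    h, c = lstm_step(x, h, c)
print(h.shape, c.shape)
```

Note how `c = f * c_prev + i * c_tilde` is additive: the previous cell state is scaled, not squashed through a nonlinearity, which is what keeps the gradient highway open.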

GRU (Gated Recurrent Unit)

A simplified variant of the LSTM that merges the cell state into the hidden state and uses only two gates:

Update gate: zₜ = σ(Wz[hₜ₋₁, xₜ])
Reset gate:  rₜ = σ(Wr[hₜ₋₁, xₜ])

Candidate:   h̃ₜ = tanh(W[rₜ×hₜ₋₁, xₜ])
Hidden:      hₜ = (1-zₜ)×hₜ₋₁ + zₜ×h̃ₜ
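The same equations as a minimal code sketch (biases omitted here, matching the formulas above; sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
input_dim, hidden_dim = 3, 4
n = hidden_dim + input_dim

W_z = torch.randn(hidden_dim, n) * 0.1  # update gate
W_r = torch.randn(hidden_dim, n) * 0.1  # reset gate
W_h = torch.randn(hidden_dim, n) * 0.1  # candidate

def gru_step(x_t, h_prev):
    z = torch.sigmoid(W_z @ torch.cat([h_prev, x_t]))         # update gate
    r = torch.sigmoid(W_r @ torch.cat([h_prev, x_t]))         # reset gate
    h_tilde = torch.tanh(W_h @ torch.cat([r * h_prev, x_t]))  # candidate
    return (1 - z) * h_prev + z * h_tilde                     # interpolate

h = torch.zeros(hidden_dim)
for x in torch.randn(5, input_dim):
    h = gru_step(x, h)
print(h.shape)
```

The final line makes the GRU's role explicit: zₜ interpolates between keeping the old hidden state and adopting the candidate, playing the part of both the forget and input gates of the LSTM.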

LSTM vs GRU

Aspect        LSTM                        GRU
Gates         3 (forget, input, output)   2 (reset, update)
Parameters    More                        Fewer
Cell state    Separate                    Combined with hidden
Performance   Often slightly better       Faster, comparable

Bidirectional RNNs

Process sequence in both directions:

Forward:  h₁→ → h₂→ → h₃→ → h₄→
Backward: h₁← ← h₂← ← h₃← ← h₄←

Output: [hₜ→; hₜ←] concatenated at each step

Captures both past and future context.

Stacked RNNs

Multiple RNN layers:

Layer 2: h₁² → h₂² → h₃² → h₄²
          ↑     ↑     ↑     ↑
Layer 1: h₁¹ → h₂¹ → h₃¹ → h₄¹
          ↑     ↑     ↑     ↑
Input:   x₁    x₂    x₃    x₄

Deeper stacks let higher layers learn more abstract patterns over the lower layers' representations.
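In PyTorch, both stacking and bidirectionality are single keyword arguments on `nn.LSTM`; no manual wiring is needed. A small sketch with made-up sizes:

```python
import torch
import torch.nn as nn

# Two stacked, bidirectional LSTM layers in one module
lstm = nn.LSTM(input_size=8, hidden_size=16,
               num_layers=2, bidirectional=True, batch_first=True)

x = torch.randn(3, 10, 8)        # (batch, seq_len, features)
out, (h_n, c_n) = lstm(x)

print(out.shape)  # (3, 10, 32): forward and backward outputs concatenated
print(h_n.shape)  # (4, 3, 16): num_layers * num_directions = 4
```

The last dimension of `out` doubles because each time step carries both the forward and backward hidden states, which is why the classifier in the code example further below sizes its linear layer as `hidden_dim * 2`.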

Common Applications

Many-to-One

x₁ → x₂ → x₃ → x₄ → y

Text classification, sentiment analysis

One-to-Many

x → y₁ → y₂ → y₃ → y₄

Image captioning, music generation

Many-to-Many (Aligned)

x₁ → x₂ → x₃ → x₄
↓    ↓    ↓    ↓
y₁   y₂   y₃   y₄

POS tagging, named entity recognition

Many-to-Many (Seq2Seq)

x₁ → x₂ → x₃ → [encode] → y₁ → y₂ → y₃

Translation, summarization

RNNs vs Transformers

Aspect            RNNs                      Transformers
Parallelization   Sequential                Fully parallel
Long-range deps   Harder (even with LSTM)   Easier (attention)
Memory            O(1) per step             O(n²) attention
Training speed    Slower                    Faster (parallel)
Streaming         Natural                   Requires adaptation

Modern trend: Transformers dominate most NLP. RNNs still useful for:

  • Streaming/online processing
  • Memory-constrained settings
  • When sequence order is critical

Code Example

import torch
import torch.nn as nn

# Bidirectional LSTM for sequence classification
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)  # *2 for both directions

    def forward(self, x):
        x = self.embedding(x)          # (batch, seq_len, embed_dim)
        _, (hidden, _) = self.lstm(x)  # hidden: (2, batch, hidden_dim)
        # Concatenate the final forward and backward hidden states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(hidden)

Key Takeaways

  1. RNNs process sequences with recurrent connections
  2. Vanilla RNNs suffer from vanishing gradients
  3. LSTM uses gates and cell state to capture long-term dependencies
  4. GRU is a simpler alternative to LSTM
  5. Bidirectional RNNs capture both past and future context
  6. Transformers have largely replaced RNNs for most NLP tasks