Recurrent Neural Networks (RNNs)
Recurrent Neural Networks are designed for sequential data. Unlike feedforward networks, RNNs have connections that loop back, allowing them to maintain a "memory" of previous inputs.
The Basic Idea
Process sequences one element at a time, maintaining hidden state:
Input:   x₁ → x₂ → x₃ → x₄
          ↓    ↓    ↓    ↓
Hidden:  h₁ → h₂ → h₃ → h₄
                          ↓
Output:                   y
The hidden state carries information from previous time steps.
Vanilla RNN
Forward Pass
hₜ = tanh(Wₓₕxₜ + Wₕₕhₜ₋₁ + b)
yₜ = Wₕᵧhₜ + bᵧ
Where:
- xₜ: input at time t
- hₜ: hidden state at time t
- yₜ: output at time t
- W: weight matrices (shared across time!)
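The forward pass above can be sketched in a few lines of NumPy. The dimensions and random initialization here are illustrative assumptions, not prescribed by the equations:

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 4

# Weights are shared across all time steps
W_xh = rng.normal(scale=0.1, size=(hidden_dim, input_dim))
W_hh = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
b = np.zeros(hidden_dim)

def rnn_step(x_t, h_prev):
    """One vanilla RNN step: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b)."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b)

xs = rng.normal(size=(5, input_dim))  # a sequence of 5 inputs
h = np.zeros(hidden_dim)              # initial hidden state
for x_t in xs:
    h = rnn_step(x_t, h)              # hidden state carries history forward
print(h.shape)  # (4,)
```

Note that the same `W_xh`, `W_hh`, and `b` are reused at every step; only the hidden state changes.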
Weight Sharing
Same weights used at every time step:
- Handles variable-length sequences
- Learns patterns regardless of position
- Fewer parameters than unrolled feedforward
The Vanishing Gradient Problem
During backpropagation through time (BPTT):
∂L/∂h₁ = ∂L/∂h₄ × ∂h₄/∂h₃ × ∂h₃/∂h₂ × ∂h₂/∂h₁

Each ∂hₜ/∂hₜ₋₁ factor can be < 1 (tanh saturation, small recurrent weights).
Gradients multiply at each step:
- If < 1: Gradients vanish (can't learn long-term dependencies)
- If > 1: Gradients explode (training unstable)
Vanilla RNNs struggle with sequences longer than ~10-20 steps.
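The multiplicative effect is easy to see numerically. This toy sketch just raises a per-step gradient factor to the power of the sequence length:

```python
# Repeatedly multiplying per-step gradient factors over a long sequence
for factor in (0.9, 1.1):
    print(factor, [f"{factor ** t:.2e}" for t in (1, 10, 50, 100)])
# 0.9 shrinks toward zero (vanishing); 1.1 blows up (exploding)
```

After 100 steps, 0.9¹⁰⁰ ≈ 2.7e-5 while 1.1¹⁰⁰ ≈ 1.4e4: even factors close to 1 vanish or explode over long sequences.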
LSTM (Long Short-Term Memory)
Designed to solve vanishing gradients with a gating mechanism:
cₜ₋₁ ──[× fₜ]──[+ iₜ × c̃ₜ]──→ cₜ ──tanh──[× oₜ]──→ hₜ

fₜ = σ(...)   iₜ = σ(...)   c̃ₜ = tanh(...)   oₜ = σ(...)
The Three Gates
Forget Gate (fₜ): What to forget from cell state
fₜ = σ(Wf[hₜ₋₁, xₜ] + bf)
Input Gate (iₜ): What new information to add
iₜ = σ(Wi[hₜ₋₁, xₜ] + bi)
c̃ₜ = tanh(Wc[hₜ₋₁, xₜ] + bc)
Output Gate (oₜ): What to output
oₜ = σ(Wo[hₜ₋₁, xₜ] + bo)
hₜ = oₜ × tanh(cₜ)
Cell State Update
cₜ = fₜ × cₜ₋₁ + iₜ × c̃ₜ
The cell state provides a "highway" along which gradients can flow largely unchanged, as long as the forget gate stays near 1.
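The three gate equations and the cell state update combine into one step function. A minimal NumPy sketch, using the common trick of computing all four pre-activations with a single stacked weight matrix (sizes are illustrative assumptions):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step; W maps [h_prev; x_t] to four stacked pre-activations."""
    z = W @ np.concatenate([h_prev, x_t]) + b
    f, i, o, g = np.split(z, 4)
    f, i, o = sigmoid(f), sigmoid(i), sigmoid(o)  # gates in (0, 1)
    c_tilde = np.tanh(g)                          # candidate cell update
    c = f * c_prev + i * c_tilde                  # cell state "highway"
    h = o * np.tanh(c)                            # hidden state
    return h, c

rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 4
W = rng.normal(scale=0.1, size=(4 * hidden_dim, hidden_dim + input_dim))
b = np.zeros(4 * hidden_dim)
h = c = np.zeros(hidden_dim)
for x_t in rng.normal(size=(5, input_dim)):
    h, c = lstm_step(x_t, h, c, W, b)
print(h.shape, c.shape)  # (4,) (4,)
```

Note how `c` is only ever scaled by `f` and nudged by `i * c_tilde`; nothing squashes it through a nonlinearity between steps.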
GRU (Gated Recurrent Unit)
Simplified version of LSTM with two gates:
Update gate: zₜ = σ(Wz[hₜ₋₁, xₜ])
Reset gate: rₜ = σ(Wr[hₜ₋₁, xₜ])
Candidate: h̃ₜ = tanh(W[rₜ×hₜ₋₁, xₜ])
Hidden: hₜ = (1-zₜ)×hₜ₋₁ + zₜ×h̃ₜ
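The four GRU equations map directly to code. A minimal sketch in NumPy, with biases omitted as in the equations above and illustrative dimensions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, Wz, Wr, Wh):
    """One GRU step following the equations above (biases omitted)."""
    z = sigmoid(Wz @ np.concatenate([h_prev, x_t]))            # update gate
    r = sigmoid(Wr @ np.concatenate([h_prev, x_t]))            # reset gate
    h_tilde = np.tanh(Wh @ np.concatenate([r * h_prev, x_t]))  # candidate
    return (1 - z) * h_prev + z * h_tilde                      # interpolate

rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 4
shape = (hidden_dim, hidden_dim + input_dim)
Wz, Wr, Wh = (rng.normal(scale=0.1, size=shape) for _ in range(3))
h = np.zeros(hidden_dim)
for x_t in rng.normal(size=(5, input_dim)):
    h = gru_step(x_t, h, Wz, Wr, Wh)
print(h.shape)  # (4,)
```

The update gate zₜ plays the role of both the LSTM's forget and input gates: it decides how much old state to keep versus how much new candidate to blend in.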
LSTM vs GRU
| Aspect | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (reset, update) |
| Parameters | More | Fewer |
| Cell state | Separate | Combined with hidden |
| Performance | Often slightly better | Faster, comparable |
Bidirectional RNNs
Process sequence in both directions:
Forward:  h₁→ → h₂→ → h₃→ → h₄→
Backward: h₁← ← h₂← ← h₃← ← h₄←
Output:   [hₜ→; hₜ←] concatenated at each step
Captures both past and future context.
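In PyTorch this is just a flag on `nn.LSTM`; the batch and sequence sizes below are arbitrary assumptions to show the shapes:

```python
import torch
import torch.nn as nn

# Bidirectional LSTM doubles the feature dimension of the per-step output
lstm = nn.LSTM(input_size=8, hidden_size=16,
               batch_first=True, bidirectional=True)
x = torch.randn(2, 5, 8)          # (batch, seq_len, input_size)
out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([2, 5, 32]) -- forward and backward concatenated
print(h_n.shape)  # torch.Size([2, 2, 16]) -- one final state per direction
```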
Stacked RNNs
Multiple RNN layers:
Layer 2: h₁² → h₂² → h₃² → h₄²
↑ ↑ ↑ ↑
Layer 1: h₁¹ → h₂¹ → h₃¹ → h₄¹
↑ ↑ ↑ ↑
Input: x₁ x₂ x₃ x₄
Deeper = learns more complex patterns.
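Stacking is also a single argument in PyTorch (`num_layers`); again the sizes below are illustrative:

```python
import torch
import torch.nn as nn

# Two stacked LSTM layers: layer 2 consumes layer 1's hidden states
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
x = torch.randn(2, 5, 8)          # (batch, seq_len, input_size)
out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([2, 5, 16]) -- top layer's hidden states only
print(h_n.shape)  # torch.Size([2, 2, 16]) -- final state of each layer
```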
Common Applications
Many-to-One
x₁ → x₂ → x₃ → x₄ → y
Text classification, sentiment analysis
One-to-Many
x → y₁ → y₂ → y₃ → y₄
Image captioning, music generation
Many-to-Many (Aligned)
x₁ → x₂ → x₃ → x₄
↓ ↓ ↓ ↓
y₁ y₂ y₃ y₄
POS tagging, named entity recognition
Many-to-Many (Seq2Seq)
x₁ → x₂ → x₃ → [encode] → y₁ → y₂ → y₃
Translation, summarization
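The many-to-many and many-to-one patterns differ only in which part of the RNN's output you feed downstream. A small shape-only sketch (the linear heads and sizes are hypothetical, just to illustrate):

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(2, 5, 8)          # (batch, seq_len, input_size)
out, h_n = rnn(x)

# Many-to-many (aligned): one prediction per time step, from every hidden state
per_step = nn.Linear(16, 3)(out)       # (2, 5, 3)
# Many-to-one: one prediction per sequence, from the final hidden state
per_seq = nn.Linear(16, 3)(h_n[-1])    # (2, 3)
print(per_step.shape, per_seq.shape)
```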
RNNs vs Transformers
| Aspect | RNNs | Transformers |
|---|---|---|
| Parallelization | Sequential | Fully parallel |
| Long-range deps | Harder (even with LSTM) | Easier (attention) |
| Memory | O(1) per step | O(n²) attention |
| Training speed | Slower | Faster (parallel) |
| Streaming | Natural | Requires adaptation |
Modern trend: Transformers dominate most NLP. RNNs still useful for:
- Streaming/online processing
- Memory-constrained settings
- When sequence order is critical
Code Example
```python
import torch
import torch.nn as nn

# Bidirectional LSTM for sequence classification
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)  # ×2 for two directions

    def forward(self, x):
        x = self.embedding(x)              # (batch, seq_len, embed_dim)
        _, (hidden, _) = self.lstm(x)      # hidden: (2, batch, hidden_dim)
        # Concatenate the final forward and backward hidden states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(hidden)
```
Key Takeaways
- RNNs process sequences with recurrent connections
- Vanilla RNNs suffer from vanishing gradients
- LSTM uses gates and cell state to capture long-term dependencies
- GRU is a simpler alternative to LSTM
- Bidirectional RNNs capture both past and future context
- Transformers have largely replaced RNNs for most NLP tasks