# LSTM Networks
Long Short-Term Memory (LSTM) networks are a special type of recurrent neural network designed to learn long-term dependencies in sequential data.
## The Problem with Vanilla RNNs
Standard RNNs suffer from:
- Vanishing gradients: Gradients shrink exponentially over long sequences
- Short-term memory: Difficulty learning dependencies beyond ~10-20 timesteps
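The shrinkage is easy to see with a toy calculation. Backpropagation through T timesteps multiplies T per-step Jacobians together; the factor `w = 0.9` below is an illustrative stand-in for the norm of one such Jacobian, not a value from any real network:

```python
# Sketch: why gradients vanish in a vanilla RNN.
# Backprop through T timesteps multiplies T Jacobians together;
# if each has norm w < 1, the product shrinks exponentially.
w = 0.9        # illustrative per-step recurrent Jacobian norm
grad = 1.0
for t in range(100):   # 100 timesteps
    grad *= w
print(f"gradient scale after 100 steps: {grad:.2e}")
```

With `w` even slightly below 1, the gradient signal after 100 steps is effectively zero; with `w` above 1 it explodes instead.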
## LSTM Architecture
LSTMs introduce a cell state and gating mechanisms:
```
Ct-1 ───► × ─────────► + ──────────────────► Ct
          │            │                      │
          ft        it × C̃t                 tanh
       (forget      (input gate ×             │
        gate)        candidate)      ot ────► × ──► ht
                                  (output gate)
```
## The Three Gates
### 1. Forget Gate
Decides what information to discard from the cell state:
```
ft = σ(Wf · [ht-1, xt] + bf)
```
- Output: values between 0 (forget) and 1 (keep)
- Example: Forget old subject when new sentence starts
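The forget-gate equation can be sketched directly in NumPy. The dimensions and random weights below are illustrative assumptions, not values from any trained model:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

hidden, inp = 4, 3                  # illustrative sizes
rng = np.random.default_rng(0)
Wf = rng.normal(size=(hidden, hidden + inp))
bf = np.zeros(hidden)

h_prev = rng.normal(size=hidden)    # h_{t-1}
x_t = rng.normal(size=inp)          # x_t

# ft = σ(Wf · [ht-1, xt] + bf)
ft = sigmoid(Wf @ np.concatenate([h_prev, x_t]) + bf)
print(ft)   # each entry in (0, 1): 0 ≈ forget, 1 ≈ keep
```

Note that `Wf` acts on the concatenation `[h_{t-1}, x_t]`, so its width is `hidden + inp`; the sigmoid then squashes each pre-activation into a per-dimension "keep fraction".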
### 2. Input Gate
Decides what new information to store:
```
it = σ(Wi · [ht-1, xt] + bi)       # What to update
C̃t = tanh(Wc · [ht-1, xt] + bc)    # Candidate values
```
### 3. Output Gate
Decides what to output based on cell state:
```
ot = σ(Wo · [ht-1, xt] + bo)
ht = ot × tanh(Ct)
```
## Cell State Update
The cell state is updated as:

```
Ct = ft × Ct-1 + it × C̃t
```

where × is element-wise multiplication. When the forget gate stays near 1, gradients flow through the cell state nearly unchanged (the "highway").
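The gate and update equations above combine into a single cell step. The following NumPy sketch stacks all four gate pre-activations into one matrix `W`; the sizes and random weights are illustrative assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W, b):
    """One LSTM timestep; W maps [h_{t-1}, x_t] to the 4 gate pre-activations."""
    z = W @ np.concatenate([h_prev, x_t]) + b
    H = h_prev.size
    f_t = sigmoid(z[0:H])              # forget gate
    i_t = sigmoid(z[H:2*H])            # input gate
    C_tilde = np.tanh(z[2*H:3*H])      # candidate values
    o_t = sigmoid(z[3*H:4*H])          # output gate
    C_t = f_t * C_prev + i_t * C_tilde # Ct = ft × Ct-1 + it × C̃t
    h_t = o_t * np.tanh(C_t)           # ht = ot × tanh(Ct)
    return h_t, C_t

hidden, inp = 4, 3
rng = np.random.default_rng(0)
W = rng.normal(size=(4 * hidden, hidden + inp))
b = np.zeros(4 * hidden)
h, C = np.zeros(hidden), np.zeros(hidden)
x = rng.normal(size=inp)
h, C = lstm_step(x, h, C, W, b)
print(h.shape, C.shape)   # (4,) (4,)
```

Stacking the gates into one matrix multiply is also how framework implementations (e.g. PyTorch's `nn.LSTM`) organize the weights.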
## Why LSTMs Work
### Gradient Flow
```
Vanilla RNN: ∂ht/∂ht-k involves ∏ Wh   — repeated multiplication by the same weight matrix
LSTM:        ∂Ct/∂Ct-1 = ft            — set by the forget gate, which can stay near 1
```

Because the cell-state update is additive, backpropagation through it multiplies by the forget-gate activations rather than repeatedly by the recurrent weight matrix, so gradients can survive over long sequences.
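A toy calculation makes the contrast concrete. The values `w = 0.9` and `f = 0.99` below are illustrative assumptions: a recurrent-weight factor slightly below 1, and a forget gate the network has learned to hold near 1:

```python
# Sketch: gradient scale over 100 timesteps along each path.
T = 100
w = 0.9    # illustrative per-step recurrent-weight factor (vanilla RNN)
f = 0.99   # illustrative learned forget-gate value (LSTM cell state)

rnn_grad = w ** T    # repeated multiplication by the same weight
lstm_grad = f ** T   # ∂Ct/∂Ct-1 = ft, which the network can push toward 1
print(f"RNN: {rnn_grad:.2e}, LSTM: {lstm_grad:.2e}")
```

The RNN path collapses to roughly 10⁻⁵ of its original scale, while the LSTM path retains over a third of it, because the forget gate is a learned quantity the network can keep close to 1 when long memory is useful.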
### Selective Memory
- Forget gate: Remove irrelevant information
- Input gate: Add new relevant information
- Output gate: Control what affects current output
## LSTM vs GRU
| Feature | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (reset, update) |
| Cell state | Separate | Combined with hidden |
| Parameters | More | Fewer |
| Training | Slower | Faster |
| Performance | Often similar | Often similar |
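The parameter difference in the table can be checked directly in PyTorch; the sizes below are arbitrary. An LSTM layer has four weight blocks per layer (one per gate plus the candidate) against the GRU's three, so for the same sizes the counts sit in a 4:3 ratio:

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=128, hidden_size=256)
gru = nn.GRU(input_size=128, hidden_size=256)

n_lstm = sum(p.numel() for p in lstm.parameters())
n_gru = sum(p.numel() for p in gru.parameters())
print(n_lstm, n_gru)   # 4 weight blocks per layer vs 3
```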
### GRU Equations
```
zt = σ(Wz · [ht-1, xt])           # Update gate
rt = σ(Wr · [ht-1, xt])           # Reset gate
h̃t = tanh(W · [rt × ht-1, xt])    # Candidate
ht = (1-zt) × ht-1 + zt × h̃t      # Output
```
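The GRU equations translate to a short NumPy sketch; the sizes and random weights are illustrative, and biases are omitted as in the equations above:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, Wz, Wr, W):
    """One GRU timestep following the equations above (biases omitted)."""
    hx = np.concatenate([h_prev, x_t])
    z_t = sigmoid(Wz @ hx)                                      # update gate
    r_t = sigmoid(Wr @ hx)                                      # reset gate
    h_tilde = np.tanh(W @ np.concatenate([r_t * h_prev, x_t]))  # candidate
    return (1 - z_t) * h_prev + z_t * h_tilde                   # new hidden state

hidden, inp = 4, 3
rng = np.random.default_rng(0)
Wz = rng.normal(size=(hidden, hidden + inp))
Wr = rng.normal(size=(hidden, hidden + inp))
W = rng.normal(size=(hidden, hidden + inp))
h = gru_step(rng.normal(size=inp), np.zeros(hidden), Wz, Wr, W)
print(h.shape)   # (4,)
```

Note how the update gate `z_t` plays both the forget and input roles at once: the interpolation `(1-zt) × ht-1 + zt × h̃t` trades old state against new content with a single gate.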
## Bidirectional LSTM
Process sequences in both directions:
```
Forward:  → → → →
Backward: ← ← ← ←
Output:   concatenate (or add) the two directions
```
```python
import torch.nn as nn

bilstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=2,
    bidirectional=True,
)
# Output feature size: 256 × 2 = 512
```
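The doubled output size can be checked with a dummy batch; the batch and sequence lengths below are arbitrary:

```python
import torch
import torch.nn as nn

bilstm = nn.LSTM(input_size=128, hidden_size=256,
                 num_layers=2, bidirectional=True, batch_first=True)

x = torch.randn(8, 50, 128)          # (batch, seq, features)
output, (h_n, c_n) = bilstm(x)
print(output.shape)  # torch.Size([8, 50, 512]) — 256 per direction
print(h_n.shape)     # torch.Size([4, 8, 256]) — num_layers × 2 directions
```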
## Stacked LSTMs
Multiple LSTM layers for learning hierarchical representations:
```python
stacked_lstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=3,   # 3 stacked layers
    dropout=0.2,    # dropout between layers
)
```
## Common Applications
- Language Modeling: Predict next word
- Machine Translation: Seq2seq with attention
- Speech Recognition: Audio to text
- Sentiment Analysis: Text classification
- Time Series: Forecasting, anomaly detection
## Implementation Example
```python
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)                  # (batch, seq, embed)
        output, (hidden, cell) = self.lstm(embedded)  # hidden: (2, batch, hidden)
        # Use the final hidden states from both directions
        hidden_cat = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(hidden_cat)
```
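A quick smoke test of the same pipeline (embedding → bidirectional LSTM → linear head) with dummy token IDs; the vocabulary and layer sizes below are illustrative assumptions:

```python
import torch
import torch.nn as nn

vocab_size, embed_dim, hidden_dim, num_classes = 1000, 64, 128, 5
embedding = nn.Embedding(vocab_size, embed_dim)
lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
fc = nn.Linear(hidden_dim * 2, num_classes)

x = torch.randint(0, vocab_size, (8, 20))   # (batch, seq) of token IDs
_, (hidden, _) = lstm(embedding(x))         # hidden: (2, batch, hidden)
logits = fc(torch.cat([hidden[-2], hidden[-1]], dim=1))
print(logits.shape)   # torch.Size([8, 5]) — one score per class
```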
## Limitations
- Sequential processing: Can't parallelize across time steps
- Still limited memory: Very long sequences remain challenging
- Superseded by Transformers: For most NLP tasks
## When to Use LSTMs
- Sequential data with moderate-length dependencies
- Limited computational resources
- Time series forecasting
- When attention isn't needed
## Key Takeaways
- LSTMs solve vanishing gradients with gating mechanisms
- Cell state acts as a "memory highway" for gradient flow
- Three gates control information flow: forget, input, output
- Bidirectional LSTMs capture context from both directions
- GRUs are a simpler alternative with similar performance
- Transformers have largely replaced LSTMs for NLP