Intermediate · Deep Learning

Understand Long Short-Term Memory networks, a type of RNN designed to learn long-range dependencies by mitigating the vanishing-gradient problem.

Tags: rnn, lstm, sequence-modeling, deep-learning, time-series

LSTM Networks

Long Short-Term Memory (LSTM) networks are a special type of recurrent neural network designed to learn long-term dependencies in sequential data.

The Problem with Vanilla RNNs

Standard RNNs suffer from:

  • Vanishing gradients: Gradients shrink exponentially as they are backpropagated through many timesteps (they can also explode)
  • Short-term memory: In practice, difficulty learning dependencies that span more than roughly 10-20 timesteps
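To see why gradients vanish, here is a minimal sketch assuming a hypothetical scalar vanilla RNN h_t = tanh(w · h_{t-1} + x_t); the function name `rnn_gradient` and the values of w and x_t are illustrative only. Backpropagating from h_T to h_0 multiplies one factor w · (1 - h_t²) per step, and with these values every factor is well below 1.

```python
import math


def rnn_gradient(w, xs, h0=0.0):
    """Return dh_T/dh_0 for the hypothetical scalar RNN h_t = tanh(w * h_{t-1} + x_t)."""
    h, grad = h0, 1.0
    for x in xs:
        h = math.tanh(w * h + x)
        # Chain rule: dh_t/dh_{t-1} = w * tanh'(.) = w * (1 - h_t**2)
        grad *= w * (1.0 - h * h)
    return grad


for T in (5, 20, 50):
    print(f"T={T:2d}  dh_T/dh_0 = {rnn_gradient(0.9, [0.5] * T):.3e}")
```

Because roughly the same factor is applied at every step, the gradient decays exponentially in T; with factors above 1 the same product would explode instead.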

LSTM Architecture

LSTMs introduce a cell state and gating mechanisms:

         ┌───────────────────────────────────────┐
         │           Cell State (C_t)            │
         │  ─────────────────────────────────►   │
         │      ×            +                   │
         │      │            │                   │
         │  ┌───┴───┐   ┌────┴────┐   ┌────────┐ │
         │  │Forget │   │  Input  │   │ Output │ │
         │  │ Gate  │   │  Gate   │   │  Gate  │ │
         │  └───────┘   └─────────┘   └────────┘ │
         │     f_t          i_t          o_t     │
         └───────────────────────────────────────┘

The Three Gates

1. Forget Gate

Decides what information to discard from the cell state:

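$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

where σ is the logistic sigmoid and [h_{t-1}, x_t] is the previous hidden state concatenated with the current input.

  • Output: values between 0 (forget) and 1 (keep), one per cell-state unit
  • Example: Forget the old subject when a new sentence starts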
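A minimal pure-Python sketch of the forget gate for one timestep, assuming toy sizes (2 cell units, hidden size 2, input size 1) and hand-picked, hypothetical weights. Lists stand in for vectors, so `h_prev + x_t` is the concatenation [h_{t-1}, x_t].

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forget_gate(W_f, b_f, h_prev, x_t):
    """f_t = sigmoid(W_f · [h_{t-1}, x_t] + b_f), one value per cell-state unit."""
    concat = h_prev + x_t  # list concatenation: [h_{t-1}, x_t]
    return [sigmoid(sum(w * v for w, v in zip(row, concat)) + b)
            for row, b in zip(W_f, b_f)]


# Hypothetical weights: unit 0 keeps its memory when x_t is positive, unit 1 erases it.
W_f = [[0.0, 0.0, 4.0],
       [0.0, 0.0, -4.0]]
b_f = [0.0, 0.0]

f_t = forget_gate(W_f, b_f, h_prev=[0.1, -0.2], x_t=[1.0])
c_prev = [2.0, 2.0]
kept = [f * c for f, c in zip(f_t, c_prev)]  # f_t ⊙ C_{t-1}
print(f_t, kept)
```

Unit 0 retains almost all of its previous cell value, while unit 1 scales it close to zero.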

2. Input Gate

Decides what new information to store:

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \qquad \text{(what to update)}$$
$$\tilde{C}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c) \qquad \text{(candidate values)}$$
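The two lines above split the work: the sigmoid gate i_t decides how much to write, the tanh candidate C̃_t decides what to write, and only their element-wise product reaches the cell state. A small sketch with hypothetical weights (hidden size 2, input size 1; lists as vectors; helper names are illustrative):

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(W, b, v):
    """W · v + b for a list-of-rows matrix W."""
    return [sum(w * x for w, x in zip(row, v)) + bb for row, bb in zip(W, b)]


def input_update(W_i, b_i, W_c, b_c, h_prev, x_t):
    """Return (i_t, C~_t, i_t ⊙ C~_t) for one timestep."""
    concat = h_prev + x_t                                       # [h_{t-1}, x_t]
    i_t = [sigmoid(z) for z in affine(W_i, b_i, concat)]        # how much to write, in (0, 1)
    c_tilde = [math.tanh(z) for z in affine(W_c, b_c, concat)]  # what to write, in (-1, 1)
    return i_t, c_tilde, [i * c for i, c in zip(i_t, c_tilde)]


# Hypothetical weights: both units propose the same candidate, but only unit 0's gate opens.
i_t, c_tilde, update = input_update(
    W_i=[[0.0, 0.0, 3.0], [0.0, 0.0, -3.0]], b_i=[0.0, 0.0],
    W_c=[[0.0, 0.0, 2.0], [0.0, 0.0, 2.0]], b_c=[0.0, 0.0],
    h_prev=[0.0, 0.0], x_t=[1.0],
)
print(i_t, c_tilde, update)
```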

3. Output Gate

Decides what to output based on cell state:

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$
$$h_t = o_t \odot \tanh(C_t)$$

where ⊙ denotes element-wise multiplication.
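A small sketch of the output step, assuming gate values that have already passed through the sigmoid (the numbers are hypothetical). It shows that h_t stays in (-1, 1) even when the cell state is large, and that a nearly closed output gate hides a unit's memory from h_t without erasing it from C_t.

```python
import math


def hidden_state(o_t, c_t):
    """h_t = o_t ⊙ tanh(C_t)."""
    return [o * math.tanh(c) for o, c in zip(o_t, c_t)]


c_t = [3.0, 3.0, -0.5]    # the cell state itself is not bounded to [-1, 1]
o_t = [0.99, 0.01, 0.5]   # hypothetical output-gate activations
h_t = hidden_state(o_t, c_t)
print(h_t)
```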

Cell State Update

The cell state is updated as:

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

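Because C_{t-1} enters this update only through an element-wise product with f_t (no weight matrix, no squashing nonlinearity), gradients can flow back along the cell state largely unchanged whenever the forget gate stays close to 1. This path is often called the "highway".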
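Putting the gate and update equations together, here is a minimal single-step LSTM cell in pure Python. It is a sketch for reading the math, not a substitute for a framework implementation: the parameter layout (a dict mapping "f", "i", "c", "o" to a (W, b) pair), the random initialization, and all helper names are assumptions made for illustration. The last few lines saturate the forget gate open and the input gate shut, so the cell state is carried through unchanged.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(W, b, v):
    """W · v + b for a list-of-rows matrix W."""
    return [sum(w * x for w, x in zip(row, v)) + bb for row, bb in zip(W, b)]


def lstm_cell(params, h_prev, c_prev, x_t):
    """One LSTM step: returns (h_t, C_t). params maps 'f', 'i', 'c', 'o' to (W, b)."""
    z = h_prev + x_t                                           # [h_{t-1}, x_t]
    f = [sigmoid(a) for a in affine(*params["f"], z)]          # forget gate
    i = [sigmoid(a) for a in affine(*params["i"], z)]          # input gate
    c_tilde = [math.tanh(a) for a in affine(*params["c"], z)]  # candidate values
    o = [sigmoid(a) for a in affine(*params["o"], z)]          # output gate
    c = [ft * cp + it * ct for ft, cp, it, ct in zip(f, c_prev, i, c_tilde)]  # C_t
    h = [ot * math.tanh(ct) for ot, ct in zip(o, c)]           # h_t
    return h, c


def init_params(hidden, inp, seed=0):
    """Hypothetical small random init; real libraries use their own schemes."""
    rng = random.Random(seed)

    def matrix():
        return [[rng.uniform(-0.5, 0.5) for _ in range(hidden + inp)] for _ in range(hidden)]

    return {g: (matrix(), [0.0] * hidden) for g in ("f", "i", "c", "o")}


# Run three timesteps with hidden size 3 and input size 2.
params = init_params(hidden=3, inp=2)
h, c = [0.0] * 3, [0.0] * 3
for x_t in ([1.0, 0.0], [0.0, 1.0], [0.5, -0.5]):
    h, c = lstm_cell(params, h, c, x_t)
print(h, c)

# "Highway" check: forget gate saturated open, input gate shut -> C_t ≈ C_{t-1}.
zeros = [[0.0] * 5 for _ in range(3)]
open_highway = {"f": (zeros, [20.0] * 3), "i": (zeros, [-20.0] * 3),
                "c": (zeros, [0.0] * 3), "o": (zeros, [0.0] * 3)}
_, c_next = lstm_cell(open_highway, [0.0] * 3, [1.0, -2.0, 0.5], [1.0, 1.0])
print(c_next)
```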

Why LSTMs Work

Gradient Flow

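Vanilla RNN, with h_t = tanh(W_h h_{t-1} + W_x x_t + b): every step multiplies the gradient by the same matrix W_h and by tanh derivatives no larger than 1 (factors taken in order from t = T down to k + 1):

$$\frac{\partial h_T}{\partial h_k} = \prod_{t=k+1}^{T} \operatorname{diag}\big(\tanh'(W_h h_{t-1} + W_x x_t + b)\big)\, W_h$$

LSTM, along the direct path through the cell state: each step's Jacobian is just the forget gate:

$$\frac{\partial C_T}{\partial C_k}\bigg|_{\text{direct}} = \prod_{t=k+1}^{T} \operatorname{diag}(f_t)$$

Because the cell state is updated additively, this path contains no repeated weight matrix and no squashing nonlinearity, and the forget gates are input-dependent values the network can learn to keep near 1. The product then decays slowly, so gradients can reach distant timesteps. (Gradients also flow through h_t and the gates; the cell-state path is simply the one that avoids repeated squashing.)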
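A numerical check of the direct-path formula, as a sketch under a simplifying assumption: the gate values are held fixed (in a real LSTM they depend on h_{t-1}), and w_t is a hypothetical stand-in for i_t ⊙ C̃_t. Under that assumption the finite-difference derivative of C_T with respect to C_0 matches the product of forget gates, which stays far from zero when the gates sit near 1; the 0.5-per-step vanilla factor is likewise a made-up value for comparison.

```python
import math


def final_cell(c0, forget, writes):
    """Scalar C_t = f_t * C_{t-1} + w_t with gates held fixed; returns C_T."""
    c = c0
    for f, w in zip(forget, writes):
        c = f * c + w
    return c


T = 100
forget = [0.99] * T                           # hypothetical forget gates near 1
writes = [0.01 * (t % 3) for t in range(T)]   # hypothetical writes; they don't change dC_T/dC_0

eps = 1e-6
numeric = (final_cell(1.0 + eps, forget, writes) - final_cell(1.0, forget, writes)) / eps
analytic = math.prod(forget)                  # product of f_t over all steps
vanilla = 0.5 ** T                            # hypothetical vanilla-RNN factor of 0.5 per step

print(f"cell path: {numeric:.4f} (analytic {analytic:.4f}), vanilla RNN: {vanilla:.1e}")
```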

Selective Memory

  • Forget gate: Remove irrelevant information
  • Input gate: Add new relevant information
  • Output gate: Control what affects current output

LSTM vs GRU

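| Feature     | LSTM                      | GRU                  |
|-------------|---------------------------|----------------------|
| Gates       | 3 (forget, input, output) | 2 (reset, update)    |
| Cell state  | Separate                  | Combined with hidden |
| Parameters  | More                      | Fewer                |
| Training    | Slower                    | Faster               |
| Performance | Often similar             | Often similar        |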
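The "More / Fewer" parameters row can be made concrete. This sketch assumes PyTorch's weight layout, where one unidirectional layer has weight_ih, weight_hh, bias_ih and bias_hh, each stacked over 4 blocks for nn.LSTM and 3 for nn.GRU; libraries with a single bias vector give slightly different totals. The function name is illustrative.

```python
def recurrent_layer_params(blocks, input_size, hidden_size):
    """Parameter count for one unidirectional layer with `blocks` stacked weight blocks
    (weight_ih, weight_hh, bias_ih, bias_hh, following PyTorch's layout)."""
    per_block = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
    return blocks * per_block


lstm = recurrent_layer_params(4, input_size=128, hidden_size=256)  # i, f, g (candidate), o
gru = recurrent_layer_params(3, input_size=128, hidden_size=256)   # r, z, n (candidate)
print(lstm, gru, gru / lstm)
```

At equal sizes a GRU layer therefore has exactly 3/4 of an LSTM layer's parameters, one reason it tends to be cheaper per step.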

GRU Equations

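$$z_t = \sigma(W_z \cdot [h_{t-1}, x_t]) \qquad \text{(update gate)}$$
$$r_t = \sigma(W_r \cdot [h_{t-1}, x_t]) \qquad \text{(reset gate)}$$
$$\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t]) \qquad \text{(candidate)}$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \qquad \text{(new hidden state)}$$

Bias terms are omitted for brevity. Conventions differ between sources: PyTorch's nn.GRU documentation, for example, writes the last line as h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1}, swapping the roles of z_t and 1 - z_t.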
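A minimal pure-Python GRU step that follows the equations above literally, with biases omitted as they are there. The weights and helper names are hypothetical; pushing the update gate toward 0 shows how a GRU can copy its previous state forward, much like an LSTM with its forget gate held open.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def gru_cell(W_z, W_r, W, h_prev, x_t):
    """One GRU step: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h~_t."""
    z = [sigmoid(a) for a in matvec(W_z, h_prev + x_t)]         # update gate
    r = [sigmoid(a) for a in matvec(W_r, h_prev + x_t)]         # reset gate
    reset_h = [ri * hi for ri, hi in zip(r, h_prev)]            # r_t ⊙ h_{t-1}
    h_tilde = [math.tanh(a) for a in matvec(W, reset_h + x_t)]  # candidate
    return [(1.0 - zi) * hp + zi * ht for zi, hp, ht in zip(z, h_prev, h_tilde)]


# Hidden size 1, input size 1; a strongly negative update weight keeps z_t near 0.
h_t = gru_cell(W_z=[[0.0, -10.0]], W_r=[[0.0, 0.0]], W=[[1.0, 1.0]],
               h_prev=[0.7], x_t=[1.0])
print(h_t)
```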

Bidirectional LSTM

Process sequences in both directions:

Forward:   → → → →
Backward:  ← ← ← ←
Output:    Concatenate (what nn.LSTM does) or add

import torch.nn as nn

bilstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=2,
    bidirectional=True
)
# Each timestep's output has hidden_size × 2 = 512 features (forward and backward concatenated)
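To see what bidirectional=True buys, here is a framework-free sketch that uses a hypothetical toy scalar RNN in place of an LSTM: one pass runs left to right, one runs over the reversed sequence, and the backward outputs are reversed again so both line up by timestep before concatenation. Changing only the last input changes the backward feature at position 0 but not the forward one.

```python
import math


def scan(xs, w=0.5):
    """Toy scalar RNN h_t = tanh(w * h_{t-1} + x_t); returns every h_t."""
    h, out = 0.0, []
    for x in xs:
        h = math.tanh(w * h + x)
        out.append(h)
    return out


def bidirectional(xs):
    fwd = scan(xs)               # position t has seen x_0 .. x_t
    bwd = scan(xs[::-1])[::-1]   # position t has seen x_t .. x_{T-1}
    return [[f, b] for f, b in zip(fwd, bwd)]  # concatenate per timestep


a = bidirectional([1.0, 0.0, 0.0, 0.0])
b = bidirectional([1.0, 0.0, 0.0, 5.0])
print(a[0], b[0])  # forward feature identical, backward feature differs
```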

Stacked LSTMs

Stacking multiple LSTM layers, each consuming the previous layer's output sequence, lets the network learn hierarchical representations:

stacked_lstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=3,      # 3 stacked layers
    dropout=0.2        # Dropout on each layer's output except the last
)
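Stacking means the full output sequence of one layer becomes the input sequence of the next. A framework-free sketch, with a hypothetical toy scalar recurrent layer standing in for each LSTM layer, made-up weights, and no dropout:

```python
import math


def rnn_layer(xs, w_h, w_x):
    """Toy scalar recurrent layer standing in for one LSTM layer; returns the whole output sequence."""
    h, out = 0.0, []
    for x in xs:
        h = math.tanh(w_h * h + w_x * x)
        out.append(h)
    return out


def stacked(xs, layers):
    """Feed each layer's full output sequence into the next layer (num_layers = len(layers))."""
    seq = xs
    for w_h, w_x in layers:
        seq = rnn_layer(seq, w_h, w_x)
    return seq


xs = [0.5, -1.0, 2.0]
out = stacked(xs, layers=[(0.5, 1.0), (0.3, 0.8), (0.9, 1.2)])  # hypothetical weights, 3 layers
print(out)
```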

Common Applications

  1. Language Modeling: Predict next word
  2. Machine Translation: Seq2seq with attention
  3. Speech Recognition: Audio to text
  4. Sentiment Analysis: Text classification
  5. Time Series: Forecasting, anomaly detection

Implementation Example

import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, 
            hidden_dim, 
            batch_first=True,
            bidirectional=True
        )
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        
    def forward(self, x):
        embedded = self.embedding(x)  # (batch, seq, embed)
        output, (hidden, cell) = self.lstm(embedded)
        
        # hidden: (num_layers * 2, batch, hidden_dim); hidden[-2] and hidden[-1] are
        # the last layer's final forward and backward states
        hidden_cat = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(hidden_cat)
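The classifier picks hidden[-2] and hidden[-1]. PyTorch documents that h_n has shape (num_layers * num_directions, batch, hidden_size) and that its layers can be separated with h_n.view(num_layers, num_directions, batch, hidden_size), so rows run layer by layer with forward before backward. A quick pure-Python check of that ordering (the helper name is illustrative) suggests the last two rows are the final layer's forward and backward states for any num_layers:

```python
def h_n_rows(num_layers, num_directions=2):
    """(layer, direction) held by each row of h_n, assuming the
    h_n.view(num_layers, num_directions, batch, hidden_size) layout; direction 0 = forward."""
    return [(layer, d) for layer in range(num_layers) for d in range(num_directions)]


for num_layers in (1, 2, 3):
    rows = h_n_rows(num_layers)
    print(num_layers, "hidden[-2] ->", rows[-2], " hidden[-1] ->", rows[-1])
```

Note that batch_first=True only changes the layout of the input and output tensors, not of the hidden and cell states, so indexing the first dimension of hidden is still correct in this model.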

Limitations

  1. Sequential processing: Can't parallelize across time steps
  2. Still limited memory: Very long sequences remain challenging
  3. Superseded by Transformers: Transformers have largely replaced them for most NLP tasks

When to Use LSTMs

  • Sequential data with moderate-length dependencies
  • Limited computational resources
  • Time series forecasting
  • Tasks where a recurrent model without attention is sufficient

Key Takeaways

  1. LSTMs mitigate vanishing gradients with gating mechanisms
  2. Cell state acts as a "memory highway" for gradient flow
  3. Three gates control information flow: forget, input, output
  4. Bidirectional LSTMs capture context from both directions
  5. GRUs are a simpler alternative with similar performance
  6. Transformers have largely replaced LSTMs for NLP