The Transformer's Long-Sequence Problem
Standard attention scales as O(N²) in both memory and compute with sequence length N. A 1M-token context requires roughly 10^12 (one trillion) query-key scores per head per layer. FlashAttention avoids materializing the N × N score matrix, which brings memory down to O(N), but compute stays O(N²), and that still limits practical context lengths. Mamba (arXiv:2312.00752) proposes a fundamentally different architecture built on selective state space models.
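To make the quadratic growth concrete, here is a back-of-the-envelope calculation. It counts one query-key score per token pair for a single head in a single layer and assumes 2-byte (fp16) scores if the full matrix were materialized; `attention_cost` is a hypothetical helper, not a library function.

```python
def attention_cost(n_tokens, bytes_per_score=2):
    # One query-key score per (query, key) pair, for a single head in a single layer
    n_scores = n_tokens * n_tokens
    # Memory if the full N x N score matrix were materialized (assumed fp16 = 2 bytes per score)
    return n_scores, n_scores * bytes_per_score

for n in (4_096, 131_072, 1_000_000):
    n_scores, n_bytes = attention_cost(n)
    print(f"{n:>9} tokens: {n_scores:.1e} scores, {n_bytes / 1e9:,.1f} GB if materialized")
```

Doubling the context quadruples both numbers; at one million tokens the materialized matrix would be about 2 TB, which is why FlashAttention never stores it, even though it still computes every score.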
State Space Models: The Foundation
SSMs treat a sequence as the input to a continuous-time linear dynamical system:
$$h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t)$$
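where h(t) is the hidden state, x(t) the input, y(t) the output, and A, B, C are learned parameters. Discretizing with a step size Δ (S4 and Mamba use a zero-order hold, giving Ā = exp(ΔA) and B̄ = (ΔA)^-1 (exp(ΔA) - I) ΔB) turns this into a linear recurrence over tokens: h_t = Ā h_{t-1} + B̄ x_t, y_t = C h_t. When Ā, B̄, and C are the same at every step, unrolling the recurrence gives a convolution of x with the kernel K = (C B̄, C Ā B̄, C Ā² B̄, ...), so training can process the whole sequence in parallel (like attention), while inference runs the recurrence with constant compute and memory per step (unlike attention, which must attend to all previous tokens).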
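The equivalence between the recurrent and convolutional views can be checked numerically. The sketch below assumes a diagonal A (as in S4D-style models), discretizes it with a zero-order hold, and compares a step-by-step recurrence against a causal convolution with the kernel K_k = C Ā^k B̄. The helper names and the values of A, B, C, and Δ are illustrative assumptions, not taken from any library.

```python
import numpy as np

def discretize_zoh(A_diag, B, delta):
    # Zero-order hold for a diagonal A: A_bar = exp(delta*A), B_bar = (exp(delta*A) - 1) / A * B
    A_bar = np.exp(delta * A_diag)
    B_bar = (A_bar - 1.0) / A_diag * B
    return A_bar, B_bar

def ssm_recurrent(A_bar, B_bar, C, x):
    # Inference mode: h_t = A_bar * h_{t-1} + B_bar * x_t, y_t = C . h_t (constant work per step)
    h = np.zeros_like(A_bar)
    y = np.empty(len(x))
    for t, x_t in enumerate(x):
        h = A_bar * h + B_bar * x_t
        y[t] = C @ h
    return y

def ssm_convolution(A_bar, B_bar, C, x):
    # Training mode: precompute the kernel K_k = C . (A_bar^k * B_bar), then one causal convolution
    L = len(x)
    K = np.array([C @ (A_bar**k * B_bar) for k in range(L)])
    return np.convolve(x, K)[:L]  # y_t = sum_{k<=t} K_k x_{t-k}

rng = np.random.default_rng(0)
N, L = 4, 16
A_diag = -np.arange(1.0, N + 1)  # illustrative stable (negative) diagonal A
A_bar, B_bar = discretize_zoh(A_diag, B=np.ones(N), delta=0.1)
C, x = rng.normal(size=N), rng.normal(size=L)
print(np.allclose(ssm_recurrent(A_bar, B_bar, C, x), ssm_convolution(A_bar, B_bar, C, x)))  # True
```

The convolution only exists because Ā and B̄ are fixed across time; once they depend on the token, there is no single kernel K to precompute.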
The Selective Mechanism in Mamba
Prior SSMs such as S4 are linear time-invariant: A, B, C, and the step size Δ are the same for every token. That makes them poor at content-aware selection, i.e. deciding from the token itself what to remember and what to forget. Mamba makes Δ, B, and C functions of the current input. A itself stays input-independent, but its discretized form Ā = exp(ΔA) now varies through Δ, so the model can choose per token whether to overwrite its hidden state or carry it forward almost unchanged. The authors credit this selectivity with letting Mamba match Transformer quality on language modeling.
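A one-dimensional toy shows what selectivity buys. With a scalar A = -1 and B = 1, zero-order hold gives Ā = exp(-Δ) and B̄ = 1 - exp(-Δ): a large Δ overwrites the state with the current token, and a tiny Δ leaves it almost untouched. Here the per-token Δ values are hand-set from a hypothetical `important` flag rather than computed from the input by a learned projection, so this only illustrates the mechanism, not how Mamba actually produces Δ.

```python
import math

def scalar_ssm(xs, deltas, A=-1.0, B=1.0):
    # Scalar SSM with zero-order hold: A_bar = exp(delta*A), B_bar = (exp(delta*A) - 1) / A * B
    h = 0.0
    for x, delta in zip(xs, deltas):
        A_bar = math.exp(delta * A)
        B_bar = (A_bar - 1.0) / A * B
        h = A_bar * h + B_bar * x
    return h

xs = [3.0, 7.0, 1.0, 2.0, 1.0, 0.0]
important = [False, True, False, False, False, False]  # hypothetical per-token flag

# Selective: delta depends on the token (hand-set here; in Mamba it is a function of the input)
selective = scalar_ssm(xs, [10.0 if keep else 1e-4 for keep in important])
# Time-invariant: the same delta for every token
lti = scalar_ssm(xs, [1.0] * len(xs))
print(round(selective, 3), round(lti, 3))
```

The selective run ends close to 7, the one flagged token, while the time-invariant run, which has to treat every token identically, ends up dominated by the most recent inputs.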
Hardware-Parallel Scan Algorithm
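Making Δ, B, and C input-dependent breaks the convolution path: with a different Ā and B̄ at every step there is no single fixed kernel. Mamba instead computes the recurrence with a parallel scan, the idea behind a parallel prefix sum generalized to an associative operator on (Ā_t, B̄_t x_t) pairs. A work-efficient scan does O(N) total work in O(log N) parallel steps, which maps well onto GPUs. The scan runs in a custom fused CUDA kernel that, similar in spirit to FlashAttention, loads the parameters into fast on-chip SRAM, performs the discretization and scan there, writes only the outputs back to HBM, and recomputes intermediate states during the backward pass instead of storing them.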
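The scan works because each step h_t = a_t h_{t-1} + b_t (with a_t standing in for Ā_t and b_t for B̄_t x_t) is an affine map of the previous state, and composing affine maps is associative. The sketch below uses scalar states and the simpler Hillis-Steele scan, which needs only ceil(log2 N) data-parallel steps but does O(N log N) total work; a work-efficient Blelloch-style scan reaches O(N) work with an up-sweep and a down-sweep. It is plain NumPy, not Mamba's fused kernel, and the function names are illustrative.

```python
import numpy as np

def sequential_scan(a, b):
    # Reference recurrence: h_t = a_t * h_{t-1} + b_t, starting from h_{-1} = 0
    h, out = 0.0, np.empty(len(a))
    for t in range(len(a)):
        h = a[t] * h + b[t]
        out[t] = h
    return out

def parallel_scan(a, b):
    # Hillis-Steele inclusive scan over pairs (a_t, b_t), each representing the map h -> a_t*h + b_t.
    # Applying (a1, b1) then (a2, b2) gives (a1*a2, a2*b1 + b2); this combine is associative.
    # Every iteration updates all positions independently, and there are ceil(log2 N) iterations.
    a, b = np.array(a, dtype=float), np.array(b, dtype=float)
    d = 1
    while d < len(a):
        a_new, b_new = a.copy(), b.copy()
        a_new[d:] = a[:-d] * a[d:]
        b_new[d:] = a[d:] * b[:-d] + b[d:]
        a, b = a_new, b_new
        d *= 2
    return b  # b_t now holds h_t

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=37)  # stands in for the per-step decay A_bar_t
b = rng.normal(size=37)             # stands in for B_bar_t * x_t
print(np.allclose(parallel_scan(a, b), sequential_scan(a, b)))  # True
```

Because the combine step is associative, the same pairs can be grouped in any order, which is exactly the freedom a GPU scan exploits.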
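# Simplified Mamba block structure (sequential selective scan; the real model uses a fused CUDA kernel)
import torch
import torch.nn as nn
import torch.nn.functional as F
class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand
        self.d_state = d_state
        self.in_proj = nn.Linear(d_model, d_inner * 2)  # SSM branch x and gate branch z
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1)
        self.x_proj = nn.Linear(d_inner, d_state * 2 + 1)  # per-token B, C, dt
        self.dt_proj = nn.Linear(1, d_inner)  # simplification: rank-1 dt projected to every channel
        # A is input-independent; A = -exp(A_log) stays negative so the state decays
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)).repeat(d_inner, 1))
        self.D = nn.Parameter(torch.ones(d_inner))  # skip connection around the SSM
        self.out_proj = nn.Linear(d_inner, d_model)
    def forward(self, x):
        # x: (batch, seq_len, d_model)
        seq_len = x.shape[1]
        x, z = self.in_proj(x).chunk(2, dim=-1)
        # Causal depthwise conv: keep only the first seq_len outputs so no future token leaks in
        x = F.silu(self.conv1d(x.transpose(1, 2))[..., :seq_len].transpose(1, 2))
        # Apply selective SSM (simplified): B, C, and dt all depend on the current token
        B, C, dt = self.x_proj(x).split([self.d_state, self.d_state, 1], dim=-1)
        dt = F.softplus(self.dt_proj(dt))  # (batch, seq_len, d_inner), positive step sizes
        A = -torch.exp(self.A_log)  # (d_inner, d_state)
        h = x.new_zeros(x.shape[0], x.shape[2], self.d_state)
        ys = []
        for t in range(seq_len):
            dA = torch.exp(dt[:, t, :, None] * A)  # discretized A: exp(dt * A)
            dB = dt[:, t, :, None] * B[:, t, None, :]  # simplified discretization of B: dt * B
            h = dA * h + dB * x[:, t, :, None]  # (batch, d_inner, d_state)
            ys.append((h * C[:, t, None, :]).sum(-1) + self.D * x[:, t])
        y = torch.stack(ys, dim=1)
        return self.out_proj(y * F.silu(z))  # gate with SiLU(z), as in the Mamba block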
Mamba-2 Improvements
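Mamba-2 (arXiv:2405.21060) shows that selective SSMs are one instance of a broader framework called structured state space duality (SSD), which connects SSMs and linear attention through a family of structured (semiseparable) matrices. Mamba-2 restricts the state matrix A from diagonal to a scalar times the identity at each step, which lets most of the computation run as matrix multiplications that GPUs execute efficiently; the authors report that its core SSD layer is 2-8x faster than Mamba-1's selective scan.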
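The duality can be seen on a small example. Assuming Mamba-2's scalar-times-identity state matrix (one decay a_t per step), unrolling the recurrence gives y_t = Σ_{s≤t} (a_{s+1} ⋯ a_t)(C_t · B_s) x_s, which is a causal attention-like matrix with C acting as queries, B as keys, and x as values, weighted by a decay mask. The sketch computes one channel both ways; it is a naive O(T²) illustration, not the paper's chunked SSD algorithm, and the function names are hypothetical.

```python
import numpy as np

def ssd_recurrent(a, B, C, x):
    # Linear-time SSM form. a: (T,) scalar decay per step (A_t = a_t * I); B, C: (T, N); x: (T,)
    h = np.zeros(B.shape[1])
    y = np.empty(len(x))
    for t in range(len(x)):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def ssd_quadratic(a, B, C, x):
    # Quadratic "attention" form: y = (L * (C B^T)) x with a causal decay mask L
    T = len(x)
    L = np.zeros((T, T))
    for t in range(T):
        for s in range(t + 1):
            L[t, s] = np.prod(a[s + 1 : t + 1])  # decay from step s to step t (empty product = 1)
    M = L * (C @ B.T)  # M[t, s] = L[t, s] * (C_t . B_s), zero above the diagonal
    return M @ x

rng = np.random.default_rng(0)
T, N = 8, 4
a = rng.uniform(0.5, 1.0, size=T)
B, C, x = rng.normal(size=(T, N)), rng.normal(size=(T, N)), rng.normal(size=T)
print(np.allclose(ssd_recurrent(a, B, C, x), ssd_quadratic(a, B, C, x)))  # True
```

With every a_t = 1 the mask becomes a plain causal mask and the quadratic form reduces to (unnormalized) causal linear attention, which is the connection the SSD framework makes precise.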
Where Mamba Wins vs Transformers
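Mamba excels at very long sequences (the paper reports results on real data improving up to million-length sequences), streaming inference (a fixed-size state, so constant memory and compute per generated token), and tasks that require selective state tracking, such as DNA sequences, audio, and time series. Transformers remain stronger on tasks that require precise in-context learning or exact retrieval from the context window: attention keeps every previous token in its KV cache, whereas an SSM must compress the entire history into its fixed-size state.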
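The streaming advantage comes down to what each model must keep in memory while generating. The sketch below compares a Transformer's KV cache, which grows with every token, against a Mamba-style recurrent state, which does not. The model dimensions (48 layers, d_model = 2048, fp16, no grouped-query attention) are hypothetical, and counting the convolution buffer as d_conv - 1 past inputs per channel is an assumption about the minimum state needed; the function names are illustrative.

```python
def kv_cache_bytes(n_tokens, n_layers, d_model, bytes_per_value=2):
    # Keys and values: two d_model-sized vectors per token per layer (assumes no grouped-query attention)
    return 2 * n_layers * n_tokens * d_model * bytes_per_value

def ssm_state_bytes(n_layers, d_model, expand=2, d_state=16, d_conv=4, bytes_per_value=2):
    # Per layer: SSM state (d_inner x d_state) plus an assumed conv buffer of d_conv - 1 past inputs
    d_inner = expand * d_model
    return n_layers * d_inner * (d_state + d_conv - 1) * bytes_per_value

# Hypothetical model: 48 layers, d_model = 2048, fp16 values
for n_tokens in (1_000, 100_000, 1_000_000):
    kv_gb = kv_cache_bytes(n_tokens, n_layers=48, d_model=2048) / 1e9
    ssm_mb = ssm_state_bytes(n_layers=48, d_model=2048) / 1e6
    print(f"{n_tokens:>9} tokens: KV cache {kv_gb:8.1f} GB, SSM state {ssm_mb:.1f} MB")
```

For this hypothetical configuration the KV cache reaches about 393 GB at one million tokens, while the recurrent state stays at about 7.5 MB regardless of length. The flip side is the retrieval limitation above: those 7.5 MB are all the model retains about the past.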
Mamba in Production
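Jamba (AI21 Labs) is a hybrid model that interleaves Transformer attention layers with Mamba layers (and adds mixture-of-experts layers). The attention layers supply the precise in-context retrieval that pure SSMs struggle with, while the Mamba layers keep memory and compute manageable over long contexts, so the hybrid captures strengths of both architectures.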