Why Attention Is Slow: The Memory Wall
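Standard attention materializes the full N×N attention matrix in High Bandwidth Memory (HBM). For a sequence of 8192 tokens with 128-dim heads, the score matrix alone has 8192² ≈ 67M entries per head, independent of the head dimension. That is about 128 MiB per head at 16-bit precision, and several GB once multiplied across heads and batch elements, all for intermediate results. The bottleneck is not arithmetic (FLOPs) but memory bandwidth: HBM is roughly an order of magnitude slower than on-chip SRAM, so the kernel spends most of its time moving the matrix to and from HBM rather than computing.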
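As a quick check on the sizes above, the following sketch counts the bytes needed to hold the full score matrix. The 16-bit element size and the head count of 32 are assumptions chosen for illustration.

```python
def score_matrix_bytes(seq_len, batch=1, heads=1, bytes_per_elem=2):
    """Bytes needed to store the full seq_len x seq_len score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

N = 8192
per_head = score_matrix_bytes(N)             # one head, one sequence, 16-bit elements
all_heads = score_matrix_bytes(N, heads=32)  # 32 heads is an assumed count

print(f"entries per head: {N * N:,}")
print(f"per head: {per_head / 2**20:.0f} MiB")
print(f"32 heads: {all_heads / 2**30:.0f} GiB")
```

The per-head figure is modest, but it scales with both head count and batch size, and the matrix has to be written to HBM and read back more than once.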
FlashAttention 1: IO-Aware Tiling
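The original FlashAttention (arXiv:2205.14135) introduced IO-aware attention: tile the Q, K, V matrices into blocks that fit in SRAM, compute attention incrementally using the online softmax trick, and recompute attention blocks during the backward pass from saved softmax statistics instead of storing them. The full N×N matrix is never materialized, so the extra memory drops from O(N²) to O(N). HBM traffic is still quadratic in N, but it falls from Θ(Nd + N²) accesses to Θ(N²d²/M) for head dimension d and SRAM size M, which is many times fewer for typical sizes. The result is a reported 2-4x speedup over optimized baselines with exact (not approximate) attention.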
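The online softmax trick can be sketched in a few lines of pure Python. This is a minimal illustration, not the CUDA kernel: it handles one query row at a time, uses plain lists, and the function names and block size are hypothetical. The point is that a running max, a running sum, and a running output are enough to match full softmax attention without ever holding a complete score row.

```python
import math
import random

def naive_attention(q, k, v):
    """Reference: build each full score row, then softmax and weight V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out

def tiled_attention(q, k, v, block=2):
    """Visit K/V in blocks, keeping only running statistics per query row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out = []
    for qi in q:
        m = float("-inf")   # running max of the scores seen so far
        denom = 0.0         # running sum of exp(score - m)
        acc = [0.0] * d_v   # running unnormalized output
        for start in range(0, len(k), block):
            k_blk = k[start:start + block]
            v_blk = v[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            m_new = max(m, max(scores))
            correction = math.exp(m - m_new)  # rescale old state to the new max
            p = [math.exp(s - m_new) for s in scores]
            denom = denom * correction + sum(p)
            acc = [a * correction + sum(pj * vj[c] for pj, vj in zip(p, v_blk))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / denom for a in acc])
    return out

random.seed(0)
q = [[random.gauss(0, 1) for _ in range(4)] for _ in range(3)]
k = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]
v = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]

ref = naive_attention(q, k, v)
tiled = tiled_attention(q, k, v, block=2)
max_diff = max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t))
print("max abs difference:", max_diff)
```

The two results agree to floating-point rounding for any block size, which is why the tiled version counts as exact rather than approximate attention.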
FlashAttention 2: Better Parallelism
FlashAttention 2 (arXiv:2307.08691) reduced the number of non-matmul FLOPs (which have much lower throughput than matmul FLOPs on modern GPUs), parallelized across the sequence length dimension (not just batch and head), and repartitioned work between the warps of each thread block to cut shared-memory communication. Result: around 2x speedup over FA1, reaching 50-73% of theoretical peak FLOPs/s on A100.
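To see why parallelizing over the sequence dimension matters, the sketch below counts independent units of work with and without a sequence split. The batch size, head count, and 128 query rows per block are assumed values for illustration; the A100 does have 108 streaming multiprocessors (SMs).

```python
import math

def num_thread_blocks(batch, heads, seq_len, rows_per_block=None):
    """Count independent work items; rows_per_block=None means no sequence split."""
    seq_blocks = 1 if rows_per_block is None else math.ceil(seq_len / rows_per_block)
    return batch * heads * seq_blocks

A100_SMS = 108  # streaming multiprocessors on an A100

# Batch and head parallelism only: one work item per (batch, head) pair.
batch_head_only = num_thread_blocks(batch=2, heads=32, seq_len=8192)

# Adding a split over query rows (128 rows per block is an assumed tile size).
with_seq_split = num_thread_blocks(batch=2, heads=32, seq_len=8192, rows_per_block=128)

print(batch_head_only, "work items for", A100_SMS, "SMs")
print(with_seq_split, "work items for", A100_SMS, "SMs")
```

With long sequences and small batches, batch-and-head parallelism alone yields fewer work items than there are SMs, so part of the GPU sits idle. Splitting the sequence restores enough work to fill it.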
FlashAttention 3: H100-Specific Optimizations
FA3 (arXiv:2407.08608) targets NVIDIA's Hopper architecture (H100) specifically. Three key advances:
1. Asynchrony and Overlapping
H100 has a dedicated asynchronous copy engine, the Tensor Memory Accelerator (TMA), that transfers data between HBM and on-chip shared memory (SRAM) without occupying the CUDA cores. FA3 assigns data loading and matmul to separate warps (warp specialization) and overlaps them using software pipelining, so the Tensor Cores spend far less time stalled on memory transfers.
2. FP8 Support
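H100 Tensor Cores support FP8 at roughly 2x the throughput of BF16. FA3 supports FP8 attention and limits the accuracy loss with block quantization (one scale per block rather than per tensor) and incoherent processing (multiplying Q and K by a random orthogonal matrix to spread out outliers before quantizing). The paper reports close to 1.2 PFLOPs/s on H100 for FP8 attention.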
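3. Intra-Warpgroup Pipelining
FA3 restructures the computation around Hopper's asynchronous warpgroup-level matmul instructions (WGMMA). Because softmax runs on much slower non-matmul units, FA3 overlaps the softmax of one block with the two GEMMs (QK^T and attention × V) of another, both across warpgroups (ping-pong scheduling) and within a single warpgroup.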
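The benefit of overlapping data movement with compute, as in the points above, can be illustrated with a toy timing model. This is a back-of-the-envelope sketch, not a model of the real kernel: it assumes a two-stage pipeline with fixed, hypothetical per-tile load and compute costs in arbitrary time units.

```python
def serial_time(n_tiles, load, compute):
    """Each tile is loaded, then computed, with no overlap."""
    return n_tiles * (load + compute)

def pipelined_time(n_tiles, load, compute):
    """Load of tile i+1 overlaps compute of tile i (two-stage pipeline)."""
    if n_tiles == 0:
        return 0
    # The first load and the last compute cannot be hidden; in between,
    # each step takes as long as the slower of the two stages.
    return load + (n_tiles - 1) * max(load, compute) + compute

n_tiles, load, compute = 64, 2, 3  # assumed costs, arbitrary units
print("serial:   ", serial_time(n_tiles, load, compute))
print("pipelined:", pipelined_time(n_tiles, load, compute))
```

In this model the pipelined time approaches n_tiles × max(load, compute), so the shorter stage is almost entirely hidden behind the longer one.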
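# Calling FlashAttention through the flash-attn package.
# Note: this import is the package's standard interface. At the time of the FA3
# beta release, the Hopper kernels were built separately from the repository's
# hopper/ directory and imported as flash_attn_interface.
from flash_attn import flash_attn_func
import torch
# BF16 inputs with layout (batch, seqlen, nheads, headdim); this example does not use FP8
q = torch.randn(2, 8192, 32, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 8192, 32, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 8192, 32, 128, dtype=torch.bfloat16, device="cuda")
# causal=True applies the causal mask used by autoregressive models
output = flash_attn_func(q, k, v, causal=True)  # same shape as q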
Impact on Context Length
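The FlashAttention family makes 128k-context models far more practical for production inference. With a materialized score matrix, each doubling of context length costs 4x more memory (O(N²)). With tiling, attention memory grows only linearly, although compute still grows quadratically, so each doubling still costs about 4x the attention FLOPs. FA3's higher hardware utilization lowers the cost of those FLOPs. Extending context to 1M+ tokens in research settings generally also requires splitting the sequence across devices rather than relying on the kernel alone.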
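The scaling behind this section can be checked with a rough count. The sketch below assumes the usual forward-pass estimate of two matmuls per head (QK^T and attention × V) at 2·N·N·d FLOPs each, and a tiled kernel that keeps only the output plus one softmax statistic per query row. The function names are illustrative.

```python
def attention_flops(seq_len, head_dim):
    """Forward FLOPs for one head: two matmuls of 2*N*N*d FLOPs each."""
    return 4 * seq_len * seq_len * head_dim

def score_matrix_elems(seq_len):
    """Elements in a fully materialized N x N score matrix."""
    return seq_len * seq_len

def tiled_state_elems(seq_len, head_dim):
    """Elements a tiled kernel keeps: the output plus one statistic per row."""
    return seq_len * head_dim + seq_len

short, long_, d = 65_536, 131_072, 128  # doubling the context length
flops_ratio = attention_flops(long_, d) / attention_flops(short, d)
matrix_ratio = score_matrix_elems(long_) / score_matrix_elems(short)
tiled_ratio = tiled_state_elems(long_, d) / tiled_state_elems(short, d)

print("compute grows by", flops_ratio)
print("materialized matrix grows by", matrix_ratio)
print("tiled state grows by", tiled_ratio)
```

Doubling the context quadruples both the compute and the materialized matrix, but only doubles the state a tiled kernel has to keep. Long contexts therefore stop being memory-limited well before they stop being compute-limited.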
Speedup Numbers
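On H100 SXM5, the FA3 paper reports a 1.5-2.0x speedup over FA2 in the forward pass and 1.5-1.75x in the backward pass. With FP16, FA3 reaches up to about 740 TFLOPs/s, roughly 75% of the H100's theoretical peak, where FA2 achieves only about 35% utilization. The ~1.2 PFLOPs/s figure applies to FP8, not BF16: the H100's dense BF16 peak is just under 1000 TFLOPs/s.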