Flash Attention (Dao et al. 2022, "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness") is the engineering optimization that made running large transformer models at long context lengths economically viable. Standard attention needs memory that grows as O(n²) in the sequence length n, which made contexts beyond a few thousand tokens prohibitively expensive. Flash Attention rewrites the attention computation to be IO-aware, tiling it to avoid writing large intermediate matrices to slow GPU memory. The result is 2-4x faster training, 5-20x lower peak memory, and the technical groundwork for the 128k context windows that are now standard across frontier models.
The Problem Flash Attention Solves
To understand why Flash Attention matters, you need to understand the bottleneck it eliminates.
In a transformer model, attention computes a weighted sum over all positions in the sequence. For a sequence of length N, the attention matrix has N × N entries. For a 2,000-token sequence, that is 4 million entries. For a 32,000-token sequence, that is 1 billion entries.
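The entry counts above follow directly from squaring the sequence length. A minimal sketch of that arithmetic, assuming a single attention head and a single sequence (the function name is illustrative, not from any library):

```python
def attention_matrix_entries(seq_len: int) -> int:
    """Number of query-key scores for one head and one sequence (N x N)."""
    return seq_len * seq_len


# Sequence lengths mentioned in this article.
for n in (2_000, 32_000, 128_000):
    print(f"{n} tokens -> {attention_matrix_entries(n):,} entries")
```

Real models multiply these counts by the number of heads and the batch size, so the totals grow even faster in practice.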
In standard attention (implemented naively with matrix operations), this N × N attention matrix is materialized — written to and read from GPU memory — as an intermediate computation step. GPU memory has two tiers:
- SRAM (on-chip): Very fast, very small (192 KB per streaming multiprocessor on an A100, roughly 20 MB across all 108)
- HBM (GPU DRAM): Slower, large (40-80 GB on an A100)
Standard attention writes the N × N matrix to HBM repeatedly during computation. At long contexts, this creates a severe memory bandwidth bottleneck. The GPU's compute units (which can do math very fast) spend most of their time waiting for data to travel from HBM, not doing useful work.
What Flash Attention Does Differently
Flash Attention's key insight is that you do not need to materialize the full attention matrix. You can compute the attention output in tiles that fit in SRAM, process each tile completely, and never write the large intermediate matrix to HBM.
The algorithm:
- Divide the queries (Q), keys (K), and values (V) into blocks that fit in SRAM
- Load one block of Q, K, V into SRAM
- Compute the partial attention output for that block
- Accumulate the partial results using an online softmax trick that allows correct partial aggregation without seeing the full matrix
- Write only the final output (not the intermediate attention matrix) back to HBM
The online softmax algorithm (the mathematically non-obvious part) allows correct computation without materializing the full N × N matrix. Dao et al. 2022 showed that this computes exact attention: it is not an approximation, and its output matches standard attention up to floating-point rounding differences caused by the different order of operations.
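The online softmax recurrence can be illustrated in a few lines of plain Python. This is a simplified sketch rather than the actual kernel: it tiles only the keys and values (the real algorithm also tiles the queries and runs on GPU SRAM), and the function names and the block_size parameter are chosen here for illustration. For each query it keeps a running maximum, a running sum of exponentials, and a running weighted sum of values, rescaling the accumulated state whenever a new block raises the maximum.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference version: builds every row of the N x N score matrix."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_size=4):
    """Processes K and V block by block using the online softmax recurrence."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        running_max = -math.inf          # largest score seen so far
        running_sum = 0.0                # sum of exp(score - running_max)
        acc = [0.0] * len(V[0])          # unnormalized weighted sum of values
        for start in range(0, len(K), block_size):
            k_blk = K[start:start + block_size]
            v_blk = V[start:start + block_size]
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale everything accumulated under the old maximum.
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * v[j] for w, v in zip(weights, v_blk))
                   for j, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out


# Small random example: 10 tokens, head dimension 4.
random.seed(0)
N, d = 10, 4
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]

reference = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block_size=4)
max_diff = max(abs(a - b) for ra, rb in zip(reference, tiled) for a, b in zip(ra, rb))
print("largest difference:", max_diff)
```

The two results agree to within floating-point rounding, and the tiled version never holds more than one block of scores at a time.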
Concrete Performance Numbers
On an A100 80GB GPU computing attention for a sequence length of 2,048 tokens:
- Standard attention: ~30 ms
- Flash Attention 2: ~8 ms
- Speedup: ~3.7x
The speedup increases with sequence length because the memory bandwidth advantage grows as the attention matrix gets larger.
Memory reduction:
- Standard attention peak memory (seq length 8,192): ~4 GB
- Flash Attention peak memory (seq length 8,192): ~800 MB
- Reduction: ~5x
At 128,000-token contexts:
- Standard attention would require ~260 GB of memory (impossible on any single GPU)
- Flash Attention: ~8-12 GB — fits in an A100 80GB with room for model weights
This is the technical enablement behind 128k context windows. Without Flash Attention, models with long context would simply not run on available hardware.
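The reason the gap widens at long contexts is that the materialized score matrix grows quadratically while the extra state a tiled kernel keeps grows only linearly. The sketch below compares the two for the attention intermediates alone. The head count, fp16 scores, and one fp32 softmax statistic per row are assumptions chosen for illustration, so the output is not expected to reproduce the totals quoted above, which depend on configuration details this article does not state.

```python
def standard_score_bytes(seq_len, num_heads=8, bytes_per_entry=2):
    """Bytes for the full N x N score matrix of every head (one sequence).

    num_heads=8 and 2-byte (fp16) entries are illustrative assumptions.
    """
    return num_heads * seq_len * seq_len * bytes_per_entry


def tiled_stats_bytes(seq_len, num_heads=8, bytes_per_stat=4):
    """Bytes for one softmax statistic per row, which a tiled kernel keeps
    instead of the score matrix (assumed stored as 4-byte fp32)."""
    return num_heads * seq_len * bytes_per_stat


for n in (8_192, 128_000):
    quadratic = standard_score_bytes(n)
    linear = tiled_stats_bytes(n)
    print(f"{n} tokens: {quadratic / 1e9:.2f} GB vs {linear / 1e6:.2f} MB "
          f"(ratio {quadratic // linear}x)")
```

Under these assumptions the ratio between the two is N/2, so it keeps growing as the context gets longer.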
Flash Attention 2 and 3
Flash Attention 2 (Dao 2023) improved on the original by better parallelizing the work across GPU threads, achieving around 2x higher throughput than Flash Attention 1. Most training and inference frameworks now use Flash Attention 2 as the default.
Flash Attention 3 (2024, targeting Hopper architecture H100 GPUs) uses asynchronous computation to overlap memory loads with compute, achieving 1.5-2x further improvement over FA2 on H100 hardware.
How Flash Attention Reduces Training Cost
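Training cost is roughly proportional to compute time. How much a 3.7x attention speedup helps end to end depends on the share of each training step spent in attention: the remaining work (feed-forward layers, optimizer, data loading) is unaffected. When attention dominates the step time, as it does at long sequence lengths, the result can approach a 2x speedup in end-to-end training. At short sequence lengths the gain is smaller.
For a model that would cost $1 million to train with standard attention, a 2x end-to-end speedup roughly halves that to $500,000. At the scale of frontier model training (costs in the hundreds of millions of dollars), this is a significant saving.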
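The link between an attention-only speedup and the end-to-end speedup is Amdahl's law. The sketch below is a back-of-the-envelope model, not a measurement: the function names are illustrative, and the attention share of runtime is an input you would have to profile for your own model.

```python
def end_to_end_speedup(attention_fraction, attention_speedup):
    """Amdahl's law: only the attention share of the runtime gets faster."""
    remaining = (1.0 - attention_fraction) + attention_fraction / attention_speedup
    return 1.0 / remaining


def required_attention_fraction(target_speedup, attention_speedup):
    """Share of runtime attention must occupy to reach a target overall speedup."""
    return (1.0 - 1.0 / target_speedup) / (1.0 - 1.0 / attention_speedup)


# With the 3.7x attention speedup quoted above:
for fraction in (0.2, 0.5, 0.7):
    print(f"attention = {fraction:.0%} of runtime -> "
          f"{end_to_end_speedup(fraction, 3.7):.2f}x overall")

needed = required_attention_fraction(2.0, 3.7)
print(f"share needed for a 2x overall speedup: {needed:.1%}")
```

Under this model, a 2x end-to-end speedup from a 3.7x attention speedup requires attention to account for a little over two thirds of the original runtime.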
Flash Attention in Practice
Flash Attention is implemented in the major deep learning frameworks:
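PyTorch: Flash Attention is available through torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0 and later), which selects a Flash Attention kernel when the hardware, dtype, and arguments allow it and otherwise falls back to another implementation.
import torch
import torch.nn.functional as F

# Example inputs with shape (batch, heads, seq_len, head_dim).
# The Flash Attention kernel needs half-precision tensors on a CUDA GPU.
query = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# PyTorch picks the Flash Attention kernel here when it is eligible
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
)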
Hugging Face Transformers: Install the flash-attn package, then set attn_implementation="flash_attention_2" when loading a model in fp16 or bf16:
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16,  # Flash Attention 2 needs fp16 or bf16
    attn_implementation="flash_attention_2",
)
vLLM: Uses Flash Attention by default on supported hardware. No configuration required.
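Requirement: The Flash Attention 2 kernels in the flash-attn package require an NVIDIA GPU with compute capability 8.0 or higher (Ampere, Ada, or Hopper architecture) and fp16 or bf16 inputs. Volta GPUs (compute capability 7.0) are not supported, and Turing GPUs (7.5) are limited to Flash Attention 1.x. The package does not run on CPU.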
Why Flash Attention Matters for LLM Cost
Flash Attention's impact on inference cost comes primarily from enabling long-context models on more modest hardware. Before Flash Attention, serving a model with 32k+ context often required multiple GPUs for memory capacity alone. With Flash Attention, many of those models can run on a single GPU.
The compounding effect: Flash Attention enabled long context windows, which reduced the need for chunked document processing. Chunked processing splits a long document into multiple requests and recombines the results, which is expensive in both compute and API calls. A single long-context request with Flash Attention can be cheaper than five short-context requests over chunks plus the step that recombines them.