Flash Attention Explained
A practical walkthrough of how Flash Attention reduces memory traffic and speeds up transformer training.
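Standard attention materializes the full N×N attention matrix (N = sequence length) in the GPU's high-bandwidth memory (HBM), which becomes the bottleneck for long sequences. Flash Attention reformulates attention as a tiled, IO-aware algorithm that keeps partial results in fast on-chip SRAM and never writes the full matrix to HBM.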
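For reference, here is a minimal pure-Python sketch of standard attention that builds the whole score matrix. `standard_attention` and `softmax` are illustrative names, not a library API, and plain lists stand in for GPU tensors; the `scores` and `probs` lists are the N×N intermediates that a conventional GPU implementation would write to HBM.

```python
import math


def softmax(row):
    # Subtract the row max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def standard_attention(Q, K, V):
    """Reference attention that builds the full N x N score matrix.

    Q, K, V are lists of row vectors (lists of floats).
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # N x N scores: every query dotted with every key.
    scores = [[scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K] for q in Q]
    # A second N x N matrix: row-wise softmax probabilities.
    probs = [softmax(row) for row in scores]
    # Each output row is a probability-weighted sum of the value rows.
    return [[sum(p * v[c] for p, v in zip(prow, V)) for c in range(len(V[0]))]
            for prow in probs]
```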
The memory wall
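GPU compute has outpaced memory bandwidth for years, so many kernels are limited by how fast they move data rather than by how fast they do arithmetic. Attention is particularly brutal: the standard implementation writes the N×N score matrix to HBM, reads it back to apply the softmax, writes the resulting probabilities, and reads them again to multiply by V, so memory traffic grows quadratically with sequence length.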
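A quick back-of-the-envelope sketch of that quadratic term, assuming 2-byte (fp16/bf16) elements and counting a single N×N matrix for one head of one sequence; real models multiply this by the number of heads and the batch size. `score_matrix_bytes` is a hypothetical helper.

```python
def score_matrix_bytes(seq_len, bytes_per_elem=2):
    """Bytes for one head's N x N attention matrix (2 bytes per fp16/bf16 element)."""
    return seq_len * seq_len * bytes_per_elem


for n in (1024, 4096, 16384):
    mib = score_matrix_bytes(n) / 2**20
    print(f"N={n:>6}: {mib:,.0f} MiB per head, per sequence")
```

This prints 2, 32 and 512 MiB: quadrupling N multiplies the footprint, and the traffic needed to write and re-read it, by 16.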
Tiling strategy
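Flash Attention splits Q, K, and V into blocks small enough to fit in SRAM and fuses the matrix multiplies and the softmax into a single kernel. Because each K/V block covers only part of a row of scores, the softmax cannot simply be applied block by block; instead, the loop keeps a running row maximum and denominator and rescales earlier partial results whenever the maximum changes (the "online softmax" trick):
# Simplified conceptual loop (NumPy sketch of the math, not a fused GPU kernel)
import numpy as np

def flash_attention(Q, K, V, B=64):
    out = np.empty((Q.shape[0], V.shape[1]))
    for i in range(0, Q.shape[0], B):
        q = Q[i:i + B] / np.sqrt(Q.shape[1])                        # load one Q tile, pre-scaled
        m = np.full(len(q), -np.inf)                                 # running row max
        l, acc = np.zeros(len(q)), np.zeros((len(q), V.shape[1]))    # running denominator, unnormalized output
        for j in range(0, K.shape[0], B):
            s = q @ K[j:j + B].T                                     # scores for this K/V tile only
            m_new = np.maximum(m, s.max(axis=1))
            p, c = np.exp(s - m_new[:, None]), np.exp(m - m_new)     # c rescales earlier partial sums
            l, acc, m = l * c + p.sum(axis=1), acc * c[:, None] + p @ V[j:j + B], m_new
        out[i:i + B] = acc / l[:, None]                              # normalize once per Q tile
    return out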
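The key step is the online softmax. The pure-Python sketch below applies the running-max and running-denominator update to a single query row and checks it against a direct softmax over the full row. The function names are hypothetical, and no SRAM or GPU is involved; it only demonstrates that the block-by-block arithmetic gives the same answer.

```python
import math


def tiled_attention_row(q, K, V, block_size):
    """Attention output for one query row, visiting K/V one block at a time.

    Only a running max (m), running denominator (l) and an unnormalized
    accumulator (acc) survive between blocks; the full score row is never stored.
    """
    scale = 1.0 / math.sqrt(len(q))
    m, l = -math.inf, 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block_size):
        k_blk, v_blk = K[start:start + block_size], V[start:start + block_size]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s))
        c = math.exp(m - m_new)  # rescales everything accumulated so far (0.0 on the first block)
        p = [math.exp(x - m_new) for x in s]
        l = l * c + sum(p)
        acc = [a * c + sum(pj * v[i] for pj, v in zip(p, v_blk)) for i, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # normalize once, after the last block


def direct_attention_row(q, K, V):
    """Same result computed from the full score row, for comparison."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    return [sum(wj * v[i] for wj, v in zip(w, V)) / sum(w) for i in range(len(V[0]))]


q = [0.5, -1.0, 2.0]
K = [[0.1 * i, -0.2 * i, 0.3] for i in range(10)]
V = [[float(i), float(i * i)] for i in range(10)]
print(tiled_attention_row(q, K, V, block_size=4))
print(direct_attention_row(q, K, V))  # agrees with the tiled result up to rounding
```

Because every earlier contribution is multiplied by `exp(m_old - m_new)` whenever a larger score appears, the result does not depend on `block_size`. Deferring the division by `l` to the end also keeps the inner loop free of per-block normalization.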
Why it matters
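- Training: fit longer contexts, because attention memory grows linearly rather than quadratically with sequence length
- Inference: fewer HBM round trips mean lower latency for serving workloads, most noticeably on long prompts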
- Downstream: its kernels are widely used in modern LLM serving stacks, often alongside KV-cache techniques such as PagedAttention
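To make the training point concrete: standard attention typically keeps the N×N softmax matrix for the backward pass, whereas Flash Attention saves only a small number of softmax statistics per query row and recomputes attention blocks from Q, K, and V during the backward pass. The sketch below compares the two for a hypothetical configuration; the helper name, the 32-head setup, and the dtypes (2-byte probabilities, one 4-byte statistic per row) are assumptions for illustration.

```python
def attention_saved_bytes(seq_len, num_heads, batch=1,
                          prob_bytes=2, stat_bytes=4, stats_per_row=1):
    """Rough bytes kept for the backward pass, per attention layer.

    standard: the N x N probability matrix per head (prob_bytes per element).
    flash:    stats_per_row softmax statistics per query row per head; the
              attention blocks are recomputed in the backward pass.
    Q, K, V and the output are needed by both and are not counted.
    """
    rows = batch * num_heads * seq_len
    standard = rows * seq_len * prob_bytes
    flash = rows * stats_per_row * stat_bytes
    return standard, flash


for n in (2048, 8192, 32768):
    std, fl = attention_saved_bytes(n, num_heads=32)
    print(f"N={n:>6}: standard {std / 2**30:6.2f} GiB, flash {fl / 2**20:5.2f} MiB")
```

Under these assumptions the standard figure goes 0.25, 4, 64 GiB (16× per 4× in N) while the Flash Attention figure goes 0.25, 1, 4 MiB (4× per 4× in N).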