Flash Attention Explained

A practical walkthrough of how Flash Attention reduces memory traffic and speeds up transformer training.

Standard attention materializes the full $N \times N$ attention matrix in HBM (the GPU's high-bandwidth, off-chip memory), which becomes the bottleneck for long sequences. Flash Attention reformulates attention as a tiled, IO-aware algorithm that keeps partial results in fast on-chip SRAM and never writes the full matrix to HBM.
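For contrast with the tiled approach, here is a minimal NumPy sketch of standard single-head attention without masking. The `scores` array is the $N \times N$ matrix that a GPU implementation would write to HBM. The function name and the choice to return the buffer size are assumptions made for illustration, not any library's API.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Single-head attention that materializes the full N x N score matrix."""
    d = Q.shape[1]
    scores = Q @ K.T / np.sqrt(d)                         # shape (N, N): the quadratic buffer
    scores = scores - scores.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=1, keepdims=True)             # row-wise softmax
    return probs @ V, scores.nbytes                       # output and size of the score matrix in bytes

rng = np.random.default_rng(0)
N, d = 1024, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
out, score_bytes = standard_attention(Q, K, V)
print(out.shape, score_bytes)  # (1024, 64) 8388608
```

Doubling `N` here quadruples `score_bytes`, even though `Q`, `K`, and `V` only double in size.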

The memory wall

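GPU compute throughput has grown faster than memory bandwidth for years, so many operations are limited by data movement rather than arithmetic. Attention is a particularly bad case, because standard attention stores a score matrix that grows quadratically with sequence length: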

$$\text{Memory} = O(N^2) \quad \text{for sequence length } N$$
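To make the formula concrete, assume fp16 scores (2 bytes per element) and count a single head of a single sequence. At $N = 8192$ tokens the score matrix alone takes:

```latex
N^2 \times 2\ \text{bytes} = 8192^2 \times 2\ \text{bytes} = 134{,}217{,}728\ \text{bytes} = 128\ \text{MiB}
```

That figure is then multiplied by the number of heads and by the batch size.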

Tiling strategy

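Flash Attention divides Q, K, and V into blocks and fuses the softmax with the matrix multiplies. Each K/V block covers only part of a query row, so applying a plain softmax per block would normalize over the wrong set of scores. Instead, the kernel keeps a running maximum and running sum for every query row and rescales earlier partial results as new blocks arrive (an "online" softmax):

import numpy as np

def flash_attention(Q, K, V, block=64):  # simplified conceptual loop, single head
    out = np.empty((Q.shape[0], V.shape[1]))
    for i in range(0, Q.shape[0], block):
        q = Q[i:i + block] / np.sqrt(Q.shape[1])            # load one Q tile, pre-scaled by 1/sqrt(d)
        m, l = np.full(len(q), -np.inf), np.zeros(len(q))   # running row max and softmax denominator
        acc = np.zeros((len(q), V.shape[1]))                # unnormalized output for this Q tile
        for j in range(0, K.shape[0], block):
            s = q @ K[j:j + block].T                        # scores against one K tile
            m_new = np.maximum(m, s.max(axis=1))            # updated running max per row
            p, scale = np.exp(s - m_new[:, None]), np.exp(m - m_new)  # tile probs, rescale factor
            acc = acc * scale[:, None] + p @ V[j:j + block] # rescale old partials, add this tile
            l, m = l * scale + p.sum(axis=1), m_new         # same rescaling for the denominator
        out[i:i + block] = acc / l[:, None]                 # normalize once per Q tile
    return out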
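Splitting a row of scores across blocks works only because the softmax statistics of separate chunks can be merged exactly. The sketch below assumes a single score row split into two chunks. Each chunk is summarized by its maximum and its sum of exponentials, and the two summaries are combined without revisiting the raw scores. The helper names `summarize` and `merge` are hypothetical.

```python
import numpy as np

def summarize(chunk):
    """Return (max, sum of exp(x - max)) for one chunk of scores."""
    m = chunk.max()
    return m, np.exp(chunk - m).sum()

def merge(a, b):
    """Combine two (max, sum) summaries by rescaling both sums to the shared max."""
    (m_a, l_a), (m_b, l_b) = a, b
    m = max(m_a, m_b)
    return m, l_a * np.exp(m_a - m) + l_b * np.exp(m_b - m)

row = np.array([1.0, 3.0, -2.0, 5.0, 0.5, 4.0])
m, l = merge(summarize(row[:3]), summarize(row[3:]))

# The merged statistics reproduce the full-row softmax exactly.
tiled = np.exp(row - m) / l
full = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
print(np.allclose(tiled, full))  # True
```

Subtracting the running maximum keeps every exponent at or below zero, which avoids overflow. The `exp(m_old - m_new)` factor is the same correction applied to the output accumulator in a tiled kernel.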

Why it matters

  • Training: fit longer contexts without quadratic growth in attention memory in HBM
  • Inference: lower latency for serving workloads
  • Downstream: fused, IO-aware attention kernels are a building block of modern LLM serving stacks, alongside KV-cache techniques such as PagedAttention

Further reading