Paper Notes: FlashAttention

April 1, 2026
Jiangneng Li

Paper: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)

1. Tiling and Safe Online Softmax (The Forward Pass Math)

The fundamental bottleneck of standard attention is materializing the $N \times N$ attention matrices in High Bandwidth Memory (HBM): this needs $\Theta(N^2)$ memory, and the resulting HBM reads and writes, rather than the arithmetic, dominate the runtime. FlashAttention avoids this via tiling (computing block by block in fast on-chip SRAM) combined with the safe online softmax trick.
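For a sense of scale, take an illustrative sequence length of $N = 16{,}384 = 2^{14}$ with FP16 (2-byte) entries. A single $N \times N$ matrix for one head of one sequence then occupies:

$$N^2 \times 2\ \text{B} = 2^{28} \times 2\ \text{B} = 2^{29}\ \text{B} = 512\ \text{MiB}$$

Standard attention materializes both the score matrix $S$ and the probability matrix $P$ at this size, for every head and every sequence in the batch.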

The Overflow Problem & Safe Softmax

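Computed naively, softmax $e^{x_i} / \sum_j e^{x_j}$ overflows when some $x_i$ is large: in FP16, $e^{x}$ exceeds the largest finite value (65504) once $x > \ln 65504 \approx 11.09$, producing inf, and inf/inf then yields NaN. To prevent this, the maximum $m(x) = \max_i x_i$ is subtracted from every element before exponentiating. The common factor $e^{-m(x)}$ cancels between numerator and denominator, so the result is unchanged: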

$$\text{softmax}(x_i) = \frac{e^{x_i - m(x)}}{\sum_j e^{x_j - m(x)}}$$
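A minimal Python sketch of the difference, assuming plain lists as vectors (the function names are illustrative). Python floats are float64, so `math.exp` overflows only above roughly 709.78 rather than FP16's roughly 11.09, and the failure surfaces as an `OverflowError` rather than inf/NaN, but the cause is the same.

```python
import math

def naive_softmax(x):
    # Direct e^{x_i} / sum_j e^{x_j}: math.exp raises OverflowError
    # once an argument exceeds about 709.78 (the float64 limit).
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def safe_softmax(x):
    # Subtract the maximum first, so every exponent is <= 0 and every
    # term is at most 1; the e^{-m} factor cancels in the ratio.
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1000.0, 1001.0, 999.0]
try:
    naive_softmax(scores)
except OverflowError:
    print("naive softmax overflowed")

probs = safe_softmax(scores)
print(probs)
# Shifting every input by the same constant leaves softmax unchanged.
print(probs == safe_softmax([1.0, 2.0, 0.0]))  # True
```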

The “Time-Travel” Reweighting Trick (Online Softmax)

Because blocks are processed sequentially and earlier blocks are discarded from SRAM, we cannot retroactively subtract a newly discovered, larger maximum from old blocks. Instead, FlashAttention uses the exponential property $e^{a-b} = e^a \cdot e^{-b}$ to rescale (“decay”) the historical running states in place.
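Setting $a = x - m_{old}$ and $b = m_{new} - m_{old}$ for any score $x$ seen in an earlier block gives the rescaling step directly; summing over all earlier scores turns it into a statement about the running denominator:

$$
\begin{aligned}
e^{x - m_{new}} &= e^{x - m_{old}} \cdot e^{m_{old} - m_{new}} \\
\sum_{k \in \text{earlier blocks}} e^{x_k - m_{new}} &= e^{m_{old} - m_{new}} \sum_{k \in \text{earlier blocks}} e^{x_k - m_{old}} = e^{m_{old} - m_{new}} \, l_{old}
\end{aligned}
$$

Since $m_{new} \ge m_{old}$, the factor $e^{m_{old} - m_{new}}$ is at most 1, so the correction only shrinks old terms and cannot overflow. The same argument applies term by term to the running weighted sum of $V$ rows.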

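For each new block $j$, the GPU computes the local scores $S_j = Q K_j^T$ (the usual $1/\sqrt{d}$ scaling is omitted throughout for brevity), the row-wise local maximum $m_{local} = \text{rowmax}(S_j)$, and the local exponentiated values $\tilde{P}_{local} = \exp(S_j - m_{local})$. The running variables are then updated entirely in SRAM, with every operation applied row-wise: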

Update Running Max:

$$m_{new} = \max(m_{old}, m_{local})$$

Update Running Denominator ($l$) via Exponential Decay:

$$l_{new} = l_{old} \cdot \exp(m_{old} - m_{new}) + \exp(m_{local} - m_{new}) \cdot \text{rowsum}(\tilde{P}_{local})$$

Update Running Numerator/Output ($O$) via Weighted Sum:

$$O_{new} = O_{old} \cdot \exp(m_{old} - m_{new}) + \exp(m_{local} - m_{new}) \cdot \tilde{P}_{local} V_j$$

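By applying the decay factor $\exp(m_{old} - m_{new})$ to the history and $\exp(m_{local} - m_{new})$ to the new block, the algorithm aligns every term to the current maximum without ever reloading old $K$ and $V$ blocks. Because $O$ is accumulated as an un-normalized numerator, the final output is obtained after the last block with a single row-wise division, $O = O_{new} / l_{new}$.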
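The update rules can be checked with a small pure-Python sketch for a single query row, assuming lists stand in for the on-chip tiles and `block_size` keys are processed per step (all names are illustrative, not taken from the paper's code). Because $\tilde{P}_{local}$ is exponentiated against $m_{local}$, the sketch rescales the new block by $\exp(m_{local} - m_{new})$ as well as decaying the history by $\exp(m_{old} - m_{new})$, and divides by $l$ only once at the end.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def reference_attention_row(q, K, V):
    # Materializes the whole score row, then applies safe softmax.
    s = [dot(q, k) for k in K]
    m = max(s)
    p = [math.exp(v - m) for v in s]
    l = sum(p)
    return [sum(p[t] * V[t][c] for t in range(len(V))) / l for c in range(len(V[0]))]

def online_attention_row(q, K, V, block_size):
    # Running state: max, denominator, and un-normalized numerator.
    m_run, l_run = -math.inf, 0.0
    o_run = [0.0] * len(V[0])
    for start in range(0, len(K), block_size):
        K_j, V_j = K[start:start + block_size], V[start:start + block_size]
        s_j = [dot(q, k) for k in K_j]
        m_local = max(s_j)
        p_local = [math.exp(v - m_local) for v in s_j]
        m_new = max(m_run, m_local)
        decay = math.exp(m_run - m_new)      # exp(m_old - m_new); 0.0 on the first block
        rescale = math.exp(m_local - m_new)  # exp(m_local - m_new)
        l_run = l_run * decay + rescale * sum(p_local)
        o_run = [o_run[c] * decay
                 + rescale * sum(p_local[t] * V_j[t][c] for t in range(len(V_j)))
                 for c in range(len(o_run))]
        m_run = m_new
    return [o / l_run for o in o_run]  # normalize once, after the last block

random.seed(0)
q = [random.gauss(0, 1) for _ in range(4)]
K = [[random.gauss(0, 3) for _ in range(4)] for _ in range(10)]
V = [[random.gauss(0, 1) for _ in range(3)] for _ in range(10)]
ref = reference_attention_row(q, K, V)
for bs in (1, 3, 10):
    out = online_attention_row(q, K, V, block_size=bs)
    print(bs, max(abs(a - b) for a, b in zip(out, ref)))  # rounding-level differences
```

Any block size gives the same result up to floating-point rounding, which is the property that lets the kernel choose block sizes to fit SRAM.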

2. Loop Order Optimization: FlashAttention-1 vs. FlashAttention-2

The execution speed of these GPU kernels is heavily bound by HBM traffic: reads and writes to HBM, not arithmetic, set the pace.

FlashAttention-1 (KV Outer Loop, Q Inner Loop): FA1 iterates over $K_j, V_j$ blocks in the outer loop. For every inner-loop step over a $Q_i$ block, the partially accumulated output block $O_i$ (together with its statistics $l_i$ and $m_i$) must be read from HBM, “un-normalized” by multiplying by the old denominator, updated with the new block’s weighted sum, re-normalized by the new denominator, and written back to HBM. Every $O_i$ therefore makes one HBM round trip per KV block, a large read/write overhead.
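A pure-Python sketch of this loop order, assuming lists `O`, `l`, `m` stand in for HBM and two counters tally block-level loads and stores of $O_i$ (the function names and the block sizes `Br`, `Bc` are illustrative). As described above, $O_i$ is kept normalized between steps: each step multiplies it back by the old denominator, adds the new block, and divides by the new denominator.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def reference_attention(Q, K, V):
    out = []
    for q in Q:
        s = [dot(q, k) for k in K]
        mx = max(s)
        p = [math.exp(v - mx) for v in s]
        z = sum(p)
        out.append([sum(p[t] * V[t][c] for t in range(len(V))) / z for c in range(len(V[0]))])
    return out

def fa1_forward(Q, K, V, Br, Bc):
    """FA1-style loop order: K_j, V_j outer; Q_i inner. O stays normalized in "HBM"."""
    N, d_v = len(Q), len(V[0])
    O = [[0.0] * d_v for _ in range(N)]
    l = [0.0] * N
    m = [-math.inf] * N
    o_reads = o_writes = 0
    for j0 in range(0, len(K), Bc):               # outer loop over KV blocks
        K_j, V_j = K[j0:j0 + Bc], V[j0:j0 + Bc]
        for i0 in range(0, N, Br):                # inner loop over Q blocks
            o_reads += 1                          # load O_i, l_i, m_i from HBM
            for r in range(i0, min(i0 + Br, N)):
                s = [dot(Q[r], k) for k in K_j]
                m_local = max(s)
                p = [math.exp(v - m_local) for v in s]
                m_new = max(m[r], m_local)
                decay = math.exp(m[r] - m_new)
                rescale = math.exp(m_local - m_new)
                l_new = decay * l[r] + rescale * sum(p)
                # Un-normalize (times old l), add the new block, re-normalize (divide by new l).
                O[r] = [(l[r] * decay * O[r][c]
                         + rescale * sum(p[t] * V_j[t][c] for t in range(len(V_j)))) / l_new
                        for c in range(d_v)]
                l[r], m[r] = l_new, m_new
            o_writes += 1                         # store O_i, l_i, m_i back to HBM
    return O, o_reads, o_writes

random.seed(0)
N, d = 7, 4
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
O, reads, writes = fa1_forward(Q, K, V, Br=3, Bc=2)
print(reads, writes)  # 12 12: three Q blocks times four KV blocks, for loads and for stores
```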

FlashAttention-2 (Q Outer Loop, KV Inner Loop): FA2 pins a $Q_i$ block in the outer loop and iterates through all $K_j, V_j$ blocks in the inner loop. The running variables $O_{run}$, $m_{run}$, and $l_{run}$ stay on-chip (in registers and SRAM) for the entire inner loop. $O_{run}$ is accumulated un-normalized with the decay formula, divided by $l_{run}$ once at the end, and written to HBM exactly once. This loop swap eliminates the repeated $O_i$ round trips and lowers the constant factor of the $O(N^2 d^2 M^{-1})$ HBM-access complexity, where $M$ is the SRAM size; the asymptotic bound itself is unchanged, because $K$ and $V$ are still streamed from HBM once per $Q_i$ block.
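The same computation with the FlashAttention-2 loop order, again as an illustrative pure-Python sketch (function names, block sizes, and the store counter are assumptions). Following the FlashAttention-2 formulation, each block is exponentiated directly against $m_{new}$, which folds the $\exp(m_{local} - m_{new})$ factor into $\tilde{P}$. The per-row $m$ and $l$ are returned too, since a backward pass can reuse them.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def reference_attention(Q, K, V):
    out = []
    for q in Q:
        s = [dot(q, k) for k in K]
        mx = max(s)
        p = [math.exp(v - mx) for v in s]
        z = sum(p)
        out.append([sum(p[t] * V[t][c] for t in range(len(V))) / z for c in range(len(V[0]))])
    return out

def fa2_forward(Q, K, V, Br, Bc):
    """FA2-style loop order: Q_i outer; K_j, V_j inner. Running state stays local."""
    N, d_v = len(Q), len(V[0])
    O = [None] * N
    m_out, l_out = [0.0] * N, [0.0] * N          # per-row statistics to save
    o_writes = 0
    for i0 in range(0, N, Br):                   # outer loop: pin one Q block
        rows = range(i0, min(i0 + Br, N))
        m_run = [-math.inf] * len(rows)          # "on-chip" running state
        l_run = [0.0] * len(rows)
        o_run = [[0.0] * d_v for _ in rows]      # un-normalized accumulator
        for j0 in range(0, len(K), Bc):          # inner loop over all KV blocks
            K_j, V_j = K[j0:j0 + Bc], V[j0:j0 + Bc]
            for a, r in enumerate(rows):
                s = [dot(Q[r], k) for k in K_j]
                m_new = max(m_run[a], max(s))
                p = [math.exp(v - m_new) for v in s]  # exponentiate against m_new directly
                decay = math.exp(m_run[a] - m_new)
                l_run[a] = decay * l_run[a] + sum(p)
                o_run[a] = [decay * o_run[a][c]
                            + sum(p[t] * V_j[t][c] for t in range(len(V_j)))
                            for c in range(d_v)]
                m_run[a] = m_new
        for a, r in enumerate(rows):             # normalize once, write once
            O[r] = [x / l_run[a] for x in o_run[a]]
            m_out[r], l_out[r] = m_run[a], l_run[a]
        o_writes += 1
    return O, m_out, l_out, o_writes

random.seed(0)
N, d = 7, 4
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
O, m_out, l_out, writes = fa2_forward(Q, K, V, Br=3, Bc=2)
print(writes)  # 3: one store per Q block
```

With $T_r$ Q blocks and $T_c$ KV blocks (notation introduced here for illustration), the number of $O_i$ stores drops from $T_r \cdot T_c$ in an FA1-style loop to $T_r$, and no $O_i$ is ever read back.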

3. The Backward Pass and Gradient Recomputation

During training, the backward pass needs the $N \times N$ attention probability matrix $P$ to compute gradients via the chain rule. Writing this matrix to HBM during the forward pass would reintroduce exactly the $\Theta(N^2)$ memory cost that tiling was meant to avoid.

Checkpointing Global Statistics

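Instead of storing the $N \times N$ matrix, the forward pass saves only the output $O$ and two statistics per query row to HBM: the global row maximum ($m^{global}$) and the global row denominator ($l^{global}$). These are one scalar each per row, so the extra memory is $\Theta(N)$ rather than $\Theta(N^2)$.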
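A short sketch showing that two checkpoint formats carry the same information, assuming a single score row held in a Python list (function names are illustrative): the pair $(m, l)$, and the single logsumexp value $L = m + \log l$, which is what FlashAttention-2 saves per row. Both reconstruct identical probabilities because $\exp(s - L) = \exp(s - m) / l$.

```python
import math

def row_statistics(scores):
    # Per-row statistics a forward pass could checkpoint: max m and denominator l.
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return m, l

def probs_from_m_l(scores, m, l):
    return [math.exp(s - m) / l for s in scores]

def probs_from_logsumexp(scores, L):
    # exp(s - L) = exp(s - m - log l) = exp(s - m) / l
    return [math.exp(s - L) for s in scores]

scores = [2.0, -1.0, 0.5, 3.0]   # one query row of S = Q K^T
m, l = row_statistics(scores)
L = m + math.log(l)              # a single saved value per row
print(probs_from_m_l(scores, m, l))
print(probs_from_logsumexp(scores, L))
```

Storing $L$ halves the per-row checkpoint at the cost of one extra logarithm in the forward pass.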

On-the-Fly Recomputation and Matrix Calculus

During the backward pass, the GPU loads $Q_i$, $K_j$, $V_j$, the upstream gradient $dO_i$, and the saved statistics $m_i^{global}$ and $l_i^{global}$ for the rows of $Q_i$ into SRAM. Because the final row maxima are already known, no dynamic reweighting is needed, and the exact probability block $P_{ij}$ is reconstructed directly (row-wise):

$$P_{ij} = \frac{\exp(Q_i K_j^T - m_i^{global})}{l_i^{global}}$$
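A pure-Python sketch of this reconstruction, assuming the forward pass saved one $(m, l)$ pair per query row (helper names and the 2-by-3 block shape are illustrative). Each block of $P$ is rebuilt from $Q_i$, $K_j$, and the saved statistics alone, so blocks can be recomputed in any order and discarded right after use.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def forward_row_stats(Q, K):
    # What a forward pass could save: per-row max and denominator of S = Q K^T.
    m, l = [], []
    for q in Q:
        s = [dot(q, k) for k in K]
        m_r = max(s)
        m.append(m_r)
        l.append(sum(math.exp(v - m_r) for v in s))
    return m, l

def reconstruct_block(Q, K, m, l, rows, cols):
    # Recompute P_ij for the given query rows and key columns from saved m, l.
    return [[math.exp(dot(Q[r], K[c]) - m[r]) / l[r] for c in cols] for r in rows]

random.seed(1)
N, d = 6, 4
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
m, l = forward_row_stats(Q, K)

# Stitch P together from 2x3 blocks; no block depends on any other block.
P = [[0.0] * N for _ in range(N)]
for i0 in range(0, N, 2):
    for j0 in range(0, N, 3):
        rows, cols = range(i0, i0 + 2), range(j0, j0 + 3)
        block = reconstruct_block(Q, K, m, l, rows, cols)
        for a, r in enumerate(rows):
            for b, c in enumerate(cols):
                P[r][c] = block[a][b]
print([round(sum(row), 12) for row in P])  # every row of P sums to 1
```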

With $P_{ij}$ reconstructed locally, the gradients are computed using the multivariable chain rule, and the results are accumulated (+=):

Gradient of V:

$$dV_j \mathrel{+}= P_{ij}^T \cdot dO_i$$

Gradient of Pre-Softmax Scores ($S$):

$$dS_{ij} = P_{ij} \circ (dO_i \cdot V_j^T - D_i)$$

(where $\circ$ is element-wise multiplication, and $D_i = \text{rowsum}(dO_i \circ O_i)$)
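The form of $D_i$ follows from the softmax Jacobian. Write $dP_{ij} = dO_i V_j^T$ (from $O = PV$), let $r$ index a single query row, and let $k$ range over all keys:

$$
\begin{aligned}
dS_{rc} &= P_{rc}\Big(dP_{rc} - \sum_k P_{rk}\, dP_{rk}\Big) \\
\sum_k P_{rk}\, dP_{rk} &= \sum_k P_{rk}\,(dO_r \cdot V_k) = dO_r \cdot \sum_k P_{rk} V_k = dO_r \cdot O_r = D_r
\end{aligned}
$$

The row sum that would otherwise need a whole row of $P$ therefore collapses to a length-$d$ dot product between the saved output and its upstream gradient, which is why $D_i$ can be computed before the block loop starts.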

Gradients of Q and K:

$$dQ_i \mathrel{+}= dS_{ij} \cdot K_j$$

$$dK_j \mathrel{+}= dS_{ij}^T \cdot Q_i$$

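The += accumulation implements the sums over blocks that the chain rule requires; for example, $dV_j = \sum_i P_{ij}^T dO_i$. With KV blocks in the outer loop and Q blocks in the inner loop (the order used by the paper's backward pass), $dK_j$ and $dV_j$ accumulate on-chip and are written to HBM once per KV block, while $dQ_i$ is accumulated in HBM. As soon as a block's contributions are added, the tile-sized $P_{ij}$ and $dS_{ij}$ blocks are discarded from SRAM. The on-chip working set therefore depends only on the block sizes, not on the sequence length, and the extra HBM memory beyond inputs, outputs, and gradients is $\Theta(N)$ (the saved row statistics) rather than $\Theta(N^2)$.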
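An end-to-end check of the backward equations as a pure-Python sketch, assuming the KV-outer, Q-inner backward loop order, lists standing in for HBM and on-chip buffers, and element-by-element loops in place of the block matmuls (all names are illustrative). It recomputes each $P_{ij}$ entry from the saved $m$ and $l$, accumulates $dK_j$ and $dV_j$ locally before writing them once, accumulates $dQ_i$ globally, and compares all three gradients against central finite differences of the scalar loss $\text{sum}(O \circ G)$, for which $dO = G$.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attention_forward(Q, K, V):
    """Reference forward pass; also returns the per-row statistics m and l."""
    O, m, l = [], [], []
    for q in Q:
        s = [dot(q, k) for k in K]
        m_r = max(s)
        p = [math.exp(v - m_r) for v in s]
        l_r = sum(p)
        O.append([sum(p[t] * V[t][c] for t in range(len(V))) / l_r for c in range(len(V[0]))])
        m.append(m_r)
        l.append(l_r)
    return O, m, l

def blocked_backward(Q, K, V, O, dO, m, l, Br, Bc):
    """KV blocks outer, Q blocks inner; P and dS are recomputed, never stored."""
    N, d, d_v = len(Q), len(Q[0]), len(V[0])
    dQ = [[0.0] * d for _ in range(N)]            # accumulated in "HBM"
    dK = [[0.0] * d for _ in K]
    dV = [[0.0] * d_v for _ in V]
    D = [dot(dO[r], O[r]) for r in range(N)]      # D_i = rowsum(dO_i ∘ O_i)
    for j0 in range(0, len(K), Bc):
        cols = range(j0, min(j0 + Bc, len(K)))
        dK_j = [[0.0] * d for _ in cols]          # "on-chip" accumulators
        dV_j = [[0.0] * d_v for _ in cols]
        for i0 in range(0, N, Br):
            for r in range(i0, min(i0 + Br, N)):
                for b, c in enumerate(cols):
                    p = math.exp(dot(Q[r], K[c]) - m[r]) / l[r]  # entry of P_ij
                    ds = p * (dot(dO[r], V[c]) - D[r])           # entry of dS_ij
                    for t in range(d_v):
                        dV_j[b][t] += p * dO[r][t]               # dV_j += P_ij^T dO_i
                    for t in range(d):
                        dQ[r][t] += ds * K[c][t]                 # dQ_i += dS_ij K_j
                        dK_j[b][t] += ds * Q[r][t]               # dK_j += dS_ij^T Q_i
        for b, c in enumerate(cols):              # write dK_j, dV_j once
            dK[c], dV[c] = dK_j[b], dV_j[b]
    return dQ, dK, dV

def numeric_grad(X, f, h=1e-6):
    """Central finite differences of scalar f() w.r.t. matrix X (perturbed in place)."""
    g = [[0.0] * len(X[0]) for _ in X]
    for r in range(len(X)):
        for c in range(len(X[0])):
            orig = X[r][c]
            X[r][c] = orig + h
            up = f()
            X[r][c] = orig - h
            down = f()
            X[r][c] = orig
            g[r][c] = (up - down) / (2 * h)
    return g

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
N, d = 5, 3
Q, K, V, G = (rand_matrix(N, d) for _ in range(4))

def loss():
    # Scalar loss sum(O ∘ G); its upstream gradient dO is exactly G.
    O_cur, _, _ = attention_forward(Q, K, V)
    return sum(dot(o, g) for o, g in zip(O_cur, G))

O, m, l = attention_forward(Q, K, V)
dQ, dK, dV = blocked_backward(Q, K, V, O, G, m, l, Br=2, Bc=2)
errors = []
for analytic, X in ((dQ, Q), (dK, K), (dV, V)):
    numeric = numeric_grad(X, loss)
    errors.append(max(abs(a - b) for ra, rn in zip(analytic, numeric) for a, b in zip(ra, rn)))
print(errors)  # finite-difference noise only
```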

Jiangneng Li
Doctor of Philosophy
PhD at NTU researching database systems, Data+AI, and multimedia data analytics.