Summary. FlashAttention is an exact optimization of the original attention module. The core idea is to compute the N×N (N = sequence length) attention matrix in small tiles, such that each tile fits easily within the fast but small memory (SRAM) on the GPU. The benefits are that 1) doing so reduces accesses to the slow but large memory (HBM), improving runtime, and 2) the full attention matrix is never materialized, improving memory efficiency. A minimal sketch of the standard (non-fused) attention computation is shown below for reference.
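The sketch below is our own illustration, not the paper's code: a naive single-head attention in NumPy with no masking, using only a 1/sqrt(d) scaling. It shows the N×N score matrix that FlashAttention avoids materializing.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full N x N score matrix.
    Q, K, V have shape (N, d); the output has shape (N, d)."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                       # (N, N) scores -- O(N^2) memory
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # row-wise stable softmax
    P = P / P.sum(axis=-1, keepdims=True)
    return P @ V
```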
- Though tiling is a commonplace optimization, its application here is not at all straightforward. The main challenge arises from a critical implementation detail that is often absent from the mathematical description of the attention module. Specifically, for numerical stability, softmax over a vector of logits z is computed in practice as:
m = max(z)
Softmax[z]_i = e^(z_i - m) / (sum_j e^(z_j - m))
The main idea is to avoid exponentiating very large numbers by taking advantage of the translation invariance of the softmax operation. If we take a closer look at this formula, we see two nested reductions over the vector z: one to compute the max of z, the other to compute the sum of exponentials (see the sketch after this bullet).
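As a concrete illustration, here is a minimal sketch of this numerically stable softmax; the function name safe_softmax and the example values are our own, not from the paper.

```python
import numpy as np

def safe_softmax(z):
    """Stable softmax: softmax(z)_i = exp(z_i - m) / sum_j exp(z_j - m), m = max(z)."""
    m = np.max(z)         # reduction 1: max over z
    e = np.exp(z - m)     # all exponents are <= 0, so no overflow
    return e / np.sum(e)  # reduction 2: sum of exponentials

# Translation invariance: subtracting m leaves the result unchanged but avoids overflow.
print(safe_softmax(np.array([1000.0, 1001.0, 1002.0])))  # finite values, no inf/nan
```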
- Partitioning/tiling these two reductions simultaneously is non-trivial because the sum depends on the result of the max, and is therefore outside the purview of a traditional optimizing compiler. We speculate that this is the primary reason why FlashAttention is hard to implement and thus was not done earlier despite its intuitive motivation. A sketch of how the two reductions can nonetheless be fused with an online rescaling trick follows this bullet.
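To make the difficulty (and its resolution) concrete, below is a minimal sketch of the "online" softmax idea that FlashAttention builds on: the max and the sum are carried as running statistics across tiles, and the running sum is rescaled whenever a new max appears. The function name, tile size, and one-dimensional setting are illustrative assumptions; FlashAttention applies the same per-tile rescaling to the attention score matrix and to the partial outputs as well.

```python
import numpy as np

def online_softmax(z, tile_size=4):
    """Fuses the two softmax reductions into one streaming pass over tiles of z."""
    m = -np.inf  # running max
    s = 0.0      # running sum of exp(z_j - m)
    for start in range(0, len(z), tile_size):
        tile = z[start:start + tile_size]
        m_new = max(m, tile.max())
        # Rescale the old sum to the new max, then add this tile's contribution.
        s = s * np.exp(m - m_new) + np.exp(tile - m_new).sum()
        m = m_new
    return np.exp(z - m) / s  # final normalization

z = np.random.randn(16) * 50.0
reference = np.exp(z - z.max()) / np.exp(z - z.max()).sum()
assert np.allclose(online_softmax(z), reference)
```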
- In this work, we have seen that the goals of high performance and numerical stability interact unfavorably with one another, making it hard to write code that is both performant and numerically stable. Existing auto-schedulers (Halide, TVM) generally do not treat numerical stability as an optimization objective. Can we introduce numerical stability as an explicit goal in existing auto-scheduling systems?