MIT MLSys Discussion Group

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

Summary. FlashAttention is an exact optimization of the standard attention module: it produces the same output, just faster and with less memory. The core idea is to compute the NxN attention matrix (N = sequence length) in small tiles, sized so that each tile fits within the fast but small on-chip memory (SRAM) of the GPU. This has two benefits: 1) it reduces reads and writes to the slow but large high-bandwidth memory (HBM), improving runtime, and 2) the full NxN attention matrix is never materialized, improving memory efficiency.
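
The tiling idea above can be sketched in NumPy. This is an illustrative single-head, CPU-only sketch of the online-softmax accumulation that makes tiling possible, not the paper's CUDA kernel; the function name, tile size, and running max/sum variables are our own choices. Scores are computed one K/V tile at a time, and a running row-wise max and denominator let previously accumulated output be rescaled, so the full NxN score matrix never exists at once.

```python
import numpy as np

def tiled_attention(Q, K, V, tile=64):
    """Exact softmax(Q K^T / sqrt(d)) V, computed over K/V tiles.

    Uses the online-softmax trick: keep a running row-wise max and
    denominator, rescaling earlier partial results as new tiles arrive.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((N, d))
    running_max = np.full(N, -np.inf)   # row-wise running max of scores
    running_sum = np.zeros(N)           # row-wise running softmax denominator
    for start in range(0, N, tile):
        Kt = K[start:start + tile]      # one K/V tile (would sit in SRAM)
        Vt = V[start:start + tile]
        S = (Q @ Kt.T) * scale          # N x tile partial scores only
        tile_max = S.max(axis=1)
        new_max = np.maximum(running_max, tile_max)
        correction = np.exp(running_max - new_max)  # rescale old partials
        P = np.exp(S - new_max[:, None])
        running_sum = running_sum * correction + P.sum(axis=1)
        out = out * correction[:, None] + P @ Vt
        running_max = new_max
    return out / running_sum[:, None]
```

Because the rescaling is exact, the result matches a naive full-matrix attention up to floating-point error, which is what makes FlashAttention an exact (not approximate) method.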