FlashAttention: High-Speed, Memory-Efficient Attention for Transformers

Transformers have become the standard architecture in NLP and vision, but attention scales quadratically with sequence length in both computation and memory, which makes it a bottleneck for long sequences. FlashAttention, introduced in 2022, is a memory-aware exact attention algorithm that speeds up attention by reducing reads and writes to GPU high-bandwidth memory (HBM).

2. Bottlenecks in Standard Attention

The classic attention operation is defined as:

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$

Here, intermediate results such as the $QK^T$ matrix are materialized and stored in GPU HBM. This causes heavy memory I/O and $O(n^2)$ memory consumption, which severely limits sequence length and throughput.

3. Core Ideas of FlashAttention 1

FlashAttention rethinks the attention operation with the following ideas:

Tile-based streaming computation: Avoids storing $QK^T$ by breaking the computation into tiles that fit in GPU SRAM and registers.

Online softmax accumulation: Uses a streaming algorithm t...
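To make the contrast between the standard formula and the tiled, online-softmax idea concrete, here is a minimal NumPy sketch. It is not the FlashAttention CUDA kernel: the function names, the `block` parameter, and the choice to tile only over keys and values (leaving Q whole) are assumptions made for illustration. It only shows that processing $K$ and $V$ tile by tile, while keeping a running row maximum and a running softmax denominator, can reproduce the exact result without ever holding the full $n \times n$ score matrix.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention: materializes the full (n, n) score matrix."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                  # (n, n) scores
    S = S - S.max(axis=1, keepdims=True)      # subtract row max for stability
    P = np.exp(S)
    P /= P.sum(axis=1, keepdims=True)         # row-wise softmax
    return P @ V

def tiled_attention(Q, K, V, block=4):
    """Illustrative tiled attention with online softmax (hypothetical helper).

    Only an (n, block) slice of scores exists at any time.
    """
    n, d = Q.shape
    m = np.full((n, 1), -np.inf)              # running row maximum
    l = np.zeros((n, 1))                      # running softmax denominator
    acc = np.zeros((n, V.shape[1]))           # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb = K[start:start + block]
        Vb = V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)             # scores for this tile only
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        scale = np.exp(m - m_new)             # rescale earlier partial sums
        P = np.exp(S - m_new)
        l = l * scale + P.sum(axis=1, keepdims=True)
        acc = acc * scale + P @ Vb
        m = m_new
    return acc / l                            # normalize once at the end

rng = np.random.default_rng(0)
Q = rng.standard_normal((10, 8))
K = rng.standard_normal((10, 8))
V = rng.standard_normal((10, 8))
same = np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V, block=3))
```

The rescaling by `exp(m - m_new)` is what lets earlier tiles be corrected after a larger score shows up in a later tile. A real kernel would also tile over the rows of Q and keep each tile in on-chip SRAM, which this sketch does not attempt to model.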