Infrastructure

Flash Attention

FlashAttention, FlashAttention-2
A GPU-optimized implementation of the attention mechanism, 2–4× faster than standard attention and using significantly less memory. Flash Attention achieves this not by changing what attention computes, but by reorganizing how the computation is carried out on GPU hardware: minimizing slow memory transfers between GPU HBM and on-chip SRAM.
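
In practice, frameworks route attention through Flash Attention automatically. A minimal sketch, assuming PyTorch 2.x on a CUDA GPU with fp16 support (the exact kernel chosen depends on dtype, shapes, and hardware):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); fp16 on a CUDA device so the fused
# FlashAttention-style kernel is eligible for dispatch.
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)

# One fused call; the 4096 x 4096 attention matrix is never materialized.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```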

Why It Matters

Flash Attention is arguably the most impactful systems optimization in modern AI. It made long-context models practical, cutting attention's memory use from quadratic to (in practice) near-linear, and directly enabled the jump from 4K to 128K+ context windows. Every major LLM uses it. Without Flash Attention, today's long-context models would be prohibitively expensive.

Deep Dive

The key insight (Dao et al., 2022): standard attention materializes the full N×N attention matrix in GPU HBM (high-bandwidth memory), which is both memory-intensive (quadratic in sequence length) and slow (HBM bandwidth is the bottleneck). Flash Attention never materializes this matrix. Instead, it computes attention in tiles, loading small blocks of Q, K, V into fast on-chip SRAM, computing partial results, and accumulating them. This combines tiling (blocking the computation so each piece fits in SRAM) with kernel fusion (performing the whole attention computation in a single GPU kernel instead of several round trips to HBM).
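
A minimal NumPy sketch of the tiling idea (the function name and block size are hypothetical, and for brevity only K and V are tiled; the real implementation is a fused CUDA kernel that also tiles Q and holds each tile in SRAM). The core trick is keeping a running row-wise max and softmax denominator so earlier partial results can be rescaled as each new tile arrives:

```python
import numpy as np

def flash_attention_reference(Q, K, V, block_size=64):
    """Tiled attention with an online softmax; never forms the full N x N matrix."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))          # unnormalized output accumulator
    m = np.full(N, -np.inf)       # running row-wise max of the scores
    l = np.zeros(N)               # running softmax denominator

    for start in range(0, N, block_size):
        Kb = K[start:start + block_size]   # one K tile (lives in SRAM on a GPU)
        Vb = V[start:start + block_size]   # matching V tile
        S = (Q @ Kb.T) * scale             # scores against this tile only
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)          # correction for previously seen tiles
        P = np.exp(S - m_new[:, None])
        l = l * alpha + P.sum(axis=1)
        O = O * alpha[:, None] + P @ Vb
        m = m_new
    return O / l[:, None]

# Sanity check against naive attention that materializes the full matrix.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 32)) for _ in range(3))
S = (Q @ K.T) / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
naive = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention_reference(Q, K, V), naive)
```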

The Memory Savings

Standard attention stores the N×N attention matrix, requiring O(N²) memory. For a 128K context, that is about 1.7×10¹⁰ entries per head, roughly 34 GB in fp16, and terabytes across 128 attention heads. Flash Attention uses O(N) memory by computing the softmax incrementally and never storing the full matrix. This is what made 128K–1M context windows feasible on existing hardware. FlashAttention-2 further improved throughput by parallelizing work better across GPU thread blocks.
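
A back-of-the-envelope check of those numbers (fp16, per layer; a rough sketch that ignores the O(N·d) inputs and outputs both methods share):

```python
N, heads, bytes_fp16 = 128 * 1024, 128, 2

# Standard attention: the full N x N score matrix for every head.
full_matrix = N * N * heads * bytes_fp16
print(f"{full_matrix / 1e12:.1f} TB")   # ~4.4 TB

# Flash Attention: O(N) running statistics (row-wise max and softmax
# denominator) per head instead of the full matrix.
flash_stats = N * heads * bytes_fp16 * 2
print(f"{flash_stats / 1e6:.0f} MB")    # ~67 MB
```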

IO-Aware Algorithm Design

Flash Attention exemplifies a broader principle: on modern hardware, the bottleneck is often memory bandwidth, not compute. GPUs can perform trillions of operations per second but can only read/write hundreds of gigabytes per second from HBM. Algorithms that minimize memory traffic (even at the cost of extra computation) often win. This "IO-aware" approach is influencing how the entire field thinks about algorithm design for AI workloads.
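
The trade-off can be made concrete with rough roofline arithmetic (the figures below are ballpark assumptions for a modern data-center GPU, not vendor specifications):

```python
compute_flops = 300e12   # ~300 TFLOP/s of fp16 compute (assumed)
hbm_bandwidth = 2e12     # ~2 TB/s of HBM bandwidth (assumed)

# A kernel must do at least this many FLOPs per byte of HBM traffic,
# or the compute units sit idle waiting on memory:
print(compute_flops / hbm_bandwidth)  # 150.0 FLOPs per byte
```

Standard attention sits far below that threshold because it writes the N×N score matrix to HBM and reads it back for the softmax and the value multiply. Flash Attention keeps those intermediates in SRAM and, in the backward pass, recomputes them instead of storing them, spending cheap FLOPs to save scarce bandwidth.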
