核心洞見(Dao 等人,2022年):標準注意力在 GPU HBM(高頻寬記憶體)中實體化完整的 N×N 注意力矩陣,這既佔用大量記憶體(與序列長度呈二次方關係),又很慢(HBM 頻寬是瓶頸)。Flash Attention 從不實體化這個矩陣。相反,它以分塊方式計算注意力,將小塊的 Q、K、V 載入快速的晶片上 SRAM,計算部分結果,然後累加——一種稱為「分塊」或「核心融合」的技術。
標準注意力儲存 N×N 注意力矩陣,需要 O(N²) 記憶體。對於具有 128 個注意力頭的 128K 上下文,這是數百 GB。Flash Attention 透過增量計算 softmax 且從不儲存完整矩陣,使用 O(N) 記憶體。這就是使 128K–1M 上下文窗口在現有硬體上可行的關鍵。FlashAttention-2 透過更好地跨 GPU 執行緒塊並行化,進一步提高了吞吐量。
Flash Attention 體現了一個更廣泛的原則:在現代硬體上,瓶頸通常是記憶體頻寬,而不是計算能力。GPU 每秒可以執行數兆次運算,但每秒只能從 HBM 讀寫數百 GB。最小化記憶體流量的演算法(即使需要額外計算)通常更勝一籌。這種「IO 感知」的方法正在影響整個領域對 AI 工作負載演算法設計的思考方式。