Zubnet AI學習Wiki › Flash Attention
基礎設施

Flash Attention

FlashAttention、FlashAttention-2
一種 GPU 最佳化的注意力機制實作,比標準注意力快 2–4 倍,且顯著減少記憶體使用。Flash Attention 的實現方式不是改變注意力計算的內容,而是重新組織計算在 GPU 硬體上的執行方式——最小化 GPU HBM 和晶片上 SRAM 之間的慢速記憶體傳輸。

為什麼重要

Flash Attention 可以說是現代 AI 中影響最大的系統最佳化。它透過將注意力的記憶體使用從二次方降低到近乎線性(在實際應用中),直接使長上下文模型變得可行,使上下文窗口從 4K 跳躍到 128K 以上。每個主流 LLM 都在使用它。沒有 Flash Attention,今天的長上下文模型將會貴得令人望而卻步。

深度解析

核心洞見(Dao 等人,2022年):標準注意力在 GPU HBM(高頻寬記憶體)中實體化完整的 N×N 注意力矩陣,這既佔用大量記憶體(與序列長度呈二次方關係),又很慢(HBM 頻寬是瓶頸)。Flash Attention 從不實體化這個矩陣。相反,它以分塊方式計算注意力,將小塊的 Q、K、V 載入快速的晶片上 SRAM,計算部分結果,然後累加——一種稱為「分塊」或「核心融合」的技術。

記憶體節省

標準注意力儲存 N×N 注意力矩陣,需要 O(N²) 記憶體。對於具有 128 個注意力頭的 128K 上下文,這是數百 GB。Flash Attention 透過增量計算 softmax 且從不儲存完整矩陣,使用 O(N) 記憶體。這就是使 128K–1M 上下文窗口在現有硬體上可行的關鍵。FlashAttention-2 透過更好地跨 GPU 執行緒塊並行化,進一步提高了吞吐量。

IO 感知演算法設計

Flash Attention 體現了一個更廣泛的原則:在現代硬體上,瓶頸通常是記憶體頻寬,而不是計算能力。GPU 每秒可以執行數兆次運算,但每秒只能從 HBM 讀寫數百 GB。最小化記憶體流量的演算法(即使需要額外計算)通常更勝一籌。這種「IO 感知」的方法正在影響整個領域對 AI 工作負載演算法設計的思考方式。

相關概念

← 所有術語
← Few-Shot Learning(少樣本學習) FLOPs →