Zubnet AI學習Wiki › GQA
基礎

GQA

別名:分組查詢注意力

一種注意力變體,多個查詢頭共享單一的鍵值頭,減少 KV 快取大小而不顯著降低品質。GQA 不是每個查詢頭都有自己的 K 和 V 投影(標準 MHA),而是讓一組查詢頭共享 K 和 V 投影。Llama 2 70B、Mistral、Gemma 及大多數現代 LLM 使用 GQA。

為什麼重要

GQA 是 KV 快取記憶體問題的實用解決方案。標準多頭注意力使用 64 個頭需要每層在快取中儲存 64 組 K 和 V 張量。使用 8 個 KV 頭的 GQA 將此減少到 8 組 — 記憶體減少了 8 倍。這直接轉化為在相同硬體上服務更多併發使用者或處理更長的上下文。

深度解析

一個頻譜:多頭注意力(MHA)擁有相等數量的 Q、K、V 頭 — 最高品質,最大記憶體。多查詢注意力(MQA)有許多 Q 頭但只有一個 K 和一個 V 頭 — 最小記憶體,有些品質損失。GQA 是折衷方案:將 Q 頭分成組,每組共享一個 K 和一個 V 頭。一個有 32 個 Q 頭和 8 個 KV 組的模型,每個 KV 頭服務 4 個 Q 頭。

品質 vs. 記憶體

研究表明,使用 8 個 KV 頭的 GQA 在大多數任務上匹配 MHA 的品質,同時使用少 4–8 倍的 KV 快取記憶體。品質的保持有些令人驚訝:這暗示許多注意力頭正在學習相似的鍵值模式,因此共享它們是高效的而非限制性的。透過「再訓練」(短期微調階段)將現有的 MHA 模型轉換為 GQA 也是有效的,避免了從頭開始重新訓練的需要。

對推論的影響

GQA 帶來的 KV 快取記憶體節省直接轉化為:在相同 GPU 上有更長的上下文視窗、更多的併發請求(更高的吞吐量)、以及更快的注意力計算(需要讀取的 K 和 V 張量更少)。對於 128K 上下文的 70B 模型,MHA 和 GQA 之間的差異可能是數百 GB 的 KV 快取 — 也就是需要 8 個 GPU 和需要 4 個 GPU 之間的差異。

相關概念

← 所有術語
← GPU Gradient Descent(梯度下降) →