Zubnet AI學習Wiki › Multi-Head Attention
基礎

Multi-Head Attention

MHA
平行執行多個 attention 操作,每個都有自己學到的 queries、keys、values 投影。不是一個 attention 函數看完整的模型維度,multi-head attention 把維度分成多個「頭」(比如 4096 維的模型用 32 個頭,每個 128 維)。每個頭可以同時關注不同類型的關係。

為什麼重要

Multi-head attention 是 Transformer 如此富有表達力的原因。一個頭可能關注句法關係(主詞-動詞),另一個關注位置模式(鄰近詞),另一個關注語意相似度。這種平行的專門化讓模型能同時捕捉多種依賴,單個 attention 頭做不到這麼好。

Deep Dive

The mechanism: for each head i, the model learns separate projection matrices W_Q^i, W_K^i, W_V^i that project the input into a lower-dimensional space (head_dim = model_dim / num_heads). Each head independently computes attention: softmax(Q_i · K_i^T / √d) · V_i. The outputs of all heads are concatenated and projected back to the full model dimension through a final linear layer W_O.

Head Specialization

Research shows that different heads learn different functions. Some heads attend to the previous token (positional). Some attend to syntactically related tokens (subject to its verb). Some implement "induction" (pattern completion). Some attend broadly (gathering global context). Not all heads are equally important — pruning 20–40% of heads often has minimal impact on performance, suggesting significant redundancy.

GQA and MQA

Multi-Query Attention (MQA) uses a single key-value head shared across all query heads, reducing KV cache size by the number of heads. Grouped-Query Attention (GQA) is a middle ground: groups of query heads share a key-value head (e.g., 32 query heads with 8 KV heads). GQA preserves most of MHA's quality while dramatically reducing memory for KV cache. Llama 2 70B, Mistral, and most modern LLMs use GQA.

相關概念

← 所有術語
← Multi-Agent Systems Multimodal →