Zubnet AILearnWiki › Multi-Head Attention
Fundamentals

Multi-Head Attention

MHA
Running multiple attention operations in parallel, each with its own learned projection of the queries, keys, and values. Instead of one attention function looking at the full model dimension, multi-head attention splits the dimension into multiple "heads" (e.g., 32 heads of 128 dimensions each for a 4096-dimension model). Each head can focus on different types of relationships simultaneously.
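The head split above is just a reshape of the model dimension. A minimal NumPy sketch using the illustrative sizes from the definition (4096-dim model, 32 heads; the token count of 10 is arbitrary):

```python
import numpy as np

# Sizes from the example above: a 4096-dim model split into 32 heads.
model_dim, num_heads = 4096, 32
head_dim = model_dim // num_heads  # 128 dimensions per head

# A sequence of 10 token embeddings.
x = np.random.randn(10, model_dim)

# Reshaping splits the model dimension into per-head slices,
# so each head sees its own 128-dim view of every token.
heads = x.reshape(10, num_heads, head_dim)
print(heads.shape)  # (10, 32, 128)
```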

Why it matters

Multi-head attention is why Transformers are so expressive. One head might focus on syntactic relationships (subject-verb), another on positional patterns (nearby words), another on semantic similarity. This parallel specialization lets the model capture many types of dependencies simultaneously, which a single attention head can't do as effectively.

Deep Dive

The mechanism: for each head i, the model learns separate projection matrices W_Q^i, W_K^i, W_V^i that project the input into a lower-dimensional space (head_dim = model_dim / num_heads). Each head independently computes attention: softmax(Q_i · K_i^T / √d) · V_i, where d is the head dimension. The outputs of all heads are concatenated and projected back to the full model dimension through a final linear layer W_O.
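A minimal NumPy sketch of this forward pass. Random matrices stand in for the learned projections, the per-head W_Q^i, W_K^i, W_V^i are packed into one matrix each, and the sizes (seq_len=6, model_dim=64, 8 heads) are illustrative, not prescriptive:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, num_heads, rng):
    """Minimal MHA forward pass. x: (seq_len, model_dim)."""
    seq_len, model_dim = x.shape
    head_dim = model_dim // num_heads
    # Random stand-ins for learned weights; all per-head
    # projections are packed into a single matrix per role.
    W_Q = rng.standard_normal((model_dim, model_dim)) * 0.02
    W_K = rng.standard_normal((model_dim, model_dim)) * 0.02
    W_V = rng.standard_normal((model_dim, model_dim)) * 0.02
    W_O = rng.standard_normal((model_dim, model_dim)) * 0.02

    def split(t):  # (seq, model_dim) -> (heads, seq, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    Q, K, V = split(x @ W_Q), split(x @ W_K), split(x @ W_V)
    # Each head attends independently: softmax(Q K^T / sqrt(d)) V.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)
    out = softmax(scores) @ V                        # (heads, seq, head_dim)
    # Concatenate the heads and project back with W_O.
    concat = out.transpose(1, 0, 2).reshape(seq_len, model_dim)
    return concat @ W_O

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 64))
y = multi_head_attention(x, num_heads=8, rng=rng)
print(y.shape)  # (6, 64)
```

Note that the attention math runs over a leading heads axis, so all heads are computed in one batched matrix multiply rather than in a Python loop.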

Head Specialization

Research shows that different heads learn different functions. Some heads attend to the previous token (positional). Some attend to syntactically related tokens (subject to its verb). Some implement "induction" (pattern completion). Some attend broadly (gathering global context). Not all heads are equally important — pruning 20–40% of heads often has minimal impact on performance, suggesting significant redundancy.

GQA and MQA

Multi-Query Attention (MQA) uses a single key-value head shared across all query heads, shrinking the KV cache by a factor equal to the number of query heads. Grouped-Query Attention (GQA) is a middle ground: each group of query heads shares one key-value head (e.g., 32 query heads with 8 KV heads). GQA preserves most of MHA's quality while dramatically reducing KV cache memory. Llama 2 70B, Mistral, and most modern LLMs use GQA.
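The KV sharing can be sketched by broadcasting each KV head across its group of query heads. A hedged NumPy illustration (shapes and the `grouped_query_attention` helper are hypothetical; num_kv_heads=1 recovers MQA, num_kv_heads=num_q_heads recovers standard MHA):

```python
import numpy as np

def grouped_query_attention(Q, K, V):
    """Q: (num_q_heads, seq, d); K, V: (num_kv_heads, seq, d).
    Each group of num_q_heads / num_kv_heads query heads shares one KV head."""
    num_q_heads, seq, d = Q.shape
    group = num_q_heads // K.shape[0]
    # Repeat each KV head so it lines up with its group of query heads.
    # Only the small K, V tensors need caching; the repeat is done on the fly.
    K = np.repeat(K, group, axis=0)
    V = np.repeat(V, group, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)
    e = np.exp(scores - scores.max(-1, keepdims=True))
    return (e / e.sum(-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q = rng.standard_normal((32, 4, 16))  # 32 query heads, as in the example above
K = rng.standard_normal((8, 4, 16))   # 8 shared KV heads: 4x smaller KV cache
V = rng.standard_normal((8, 4, 16))
out = grouped_query_attention(Q, K, V)
print(out.shape)  # (32, 4, 16)
```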
