Zubnet AIApprendreWiki › Multi-Head Attention
Fondamentaux

Multi-Head Attention

MHA
Exécuter plusieurs opérations d'attention en parallèle, chacune avec sa propre projection apprise des queries, keys et values. Au lieu d'une seule fonction d'attention qui regarde toute la dimension du modèle, le multi-head attention divise la dimension en plusieurs « têtes » (ex. 32 têtes de 128 dimensions chacune pour un modèle de 4096 dimensions). Chaque tête peut se concentrer sur différents types de relations simultanément.

Pourquoi c'est important

Le multi-head attention est pourquoi les Transformers sont si expressifs. Une tête peut se concentrer sur des relations syntaxiques (sujet-verbe), une autre sur des patterns positionnels (mots proches), une autre sur la similarité sémantique. Cette spécialisation parallèle permet au modèle de capturer beaucoup de types de dépendances simultanément, ce qu'une seule tête d'attention ne peut pas faire aussi bien.

Deep Dive

The mechanism: for each head i, the model learns separate projection matrices W_Q^i, W_K^i, W_V^i that project the input into a lower-dimensional space (head_dim = model_dim / num_heads). Each head independently computes attention: softmax(Q_i · K_i^T / √d) · V_i. The outputs of all heads are concatenated and projected back to the full model dimension through a final linear layer W_O.

Head Specialization

Research shows that different heads learn different functions. Some heads attend to the previous token (positional). Some attend to syntactically related tokens (subject to its verb). Some implement "induction" (pattern completion). Some attend broadly (gathering global context). Not all heads are equally important — pruning 20–40% of heads often has minimal impact on performance, suggesting significant redundancy.

GQA and MQA

Multi-Query Attention (MQA) uses a single key-value head shared across all query heads, reducing KV cache size by the number of heads. Grouped-Query Attention (GQA) is a middle ground: groups of query heads share a key-value head (e.g., 32 query heads with 8 KV heads). GQA preserves most of MHA's quality while dramatically reducing memory for KV cache. Llama 2 70B, Mistral, and most modern LLMs use GQA.

Concepts liés

← Tous les termes
← Multi-Agent Systems Multimodal →