Le mécanisme : pour chaque tête i, le modèle apprend des matrices de projection séparées W_Q^i, W_K^i, W_V^i qui projettent l'entrée dans un espace de dimension inférieure (head_dim = model_dim / num_heads). Chaque tête calcule indépendamment l'attention : softmax(Q_i · K_i^T / √d) · V_i. Les sorties de toutes les têtes sont concaténées et projetées dans la dimension complète du modèle via une couche linéaire finale W_O.
La recherche montre que différentes têtes apprennent différentes fonctions. Certaines têtes portent attention au token précédent (positionnel). Certaines portent attention aux tokens syntaxiquement liés (sujet vers son verbe). Certaines implémentent l'"induction" (complétion de patterns). Certaines portent attention de manière large (rassemblement du contexte global). Toutes les têtes ne sont pas également importantes — élaguer 20–40% des têtes a souvent un impact minimal sur la performance, suggérant une redondance significative.
Multi-Query Attention (MQA) utilise une seule tête key-value partagée entre toutes les têtes de queries, réduisant la taille du KV cache du nombre de têtes. Grouped-Query Attention (GQA) est un compromis : des groupes de têtes de queries partagent une tête key-value (ex : 32 têtes de queries avec 8 têtes KV). GQA préserve la majeure partie de la qualité de MHA tout en réduisant dramatiquement la mémoire pour le KV cache. Llama 2 70B, Mistral et la plupart des LLM modernes utilisent GQA.