El mecanismo: para cada cabeza i, el modelo aprende matrices de proyección separadas W_Q^i, W_K^i, W_V^i que proyectan la entrada en un espacio de menor dimensión (head_dim = model_dim / num_heads). Cada cabeza calcula atención de forma independiente: softmax(Q_i · K_i^T / √d) · V_i. Las salidas de todas las cabezas se concatenan y se proyectan de vuelta a la dimensión completa del modelo mediante una capa lineal final W_O.
La investigación muestra que diferentes cabezas aprenden funciones diferentes. Algunas atienden al token anterior (posicional). Algunas atienden a tokens sintácticamente relacionados (sujeto a su verbo). Algunas implementan "inducción" (completar patrones). Algunas atienden ampliamente (recopilando contexto global). No todas las cabezas son igualmente importantes — podar del 20–40% de las cabezas a menudo tiene un impacto mínimo en el rendimiento, lo que sugiere una redundancia significativa.
Multi-Query Attention (MQA) usa una sola cabeza key-value compartida entre todas las cabezas de query, reduciendo el tamaño del KV cache por el número de cabezas. Grouped-Query Attention (GQA) es un término medio: grupos de cabezas de query comparten una cabeza key-value (por ejemplo, 32 cabezas de query con 8 cabezas KV). GQA preserva la mayor parte de la calidad de MHA mientras reduce dramáticamente la memoria para el KV cache. Llama 2 70B, Mistral y la mayoría de los LLMs modernos usan GQA.