Zubnet AIAprenderWiki › Multi-Head Attention
Fundamentos

Multi-Head Attention

MHA
Ejecutar múltiples operaciones de atención en paralelo, cada una con su propia proyección aprendida de queries, keys y values. En lugar de una función de atención mirando toda la dimensión del modelo, la multi-head attention divide la dimensión en múltiples «cabezas» (p. ej. 32 cabezas de 128 dimensiones cada una para un modelo de 4096 dimensiones). Cada cabeza puede enfocarse en distintos tipos de relaciones simultáneamente.

Por qué importa

La multi-head attention es por qué los Transformers son tan expresivos. Una cabeza puede enfocarse en relaciones sintácticas (sujeto-verbo), otra en patrones posicionales (palabras cercanas), otra en similitud semántica. Esta especialización paralela permite al modelo capturar muchos tipos de dependencias a la vez, cosa que una sola cabeza de atención no puede hacer igual de bien.

Deep Dive

The mechanism: for each head i, the model learns separate projection matrices W_Q^i, W_K^i, W_V^i that project the input into a lower-dimensional space (head_dim = model_dim / num_heads). Each head independently computes attention: softmax(Q_i · K_i^T / √d) · V_i. The outputs of all heads are concatenated and projected back to the full model dimension through a final linear layer W_O.

Head Specialization

Research shows that different heads learn different functions. Some heads attend to the previous token (positional). Some attend to syntactically related tokens (subject to its verb). Some implement "induction" (pattern completion). Some attend broadly (gathering global context). Not all heads are equally important — pruning 20–40% of heads often has minimal impact on performance, suggesting significant redundancy.

GQA and MQA

Multi-Query Attention (MQA) uses a single key-value head shared across all query heads, reducing KV cache size by the number of heads. Grouped-Query Attention (GQA) is a middle ground: groups of query heads share a key-value head (e.g., 32 query heads with 8 KV heads). GQA preserves most of MHA's quality while dramatically reducing memory for KV cache. Llama 2 70B, Mistral, and most modern LLMs use GQA.

Conceptos relacionados

← Todos los términos
← Multi-Agent Systems Multimodal →