Zubnet AIApprendreWiki › GQA
Fondamentaux

GQA

Aussi appelé : Grouped Query Attention
Une variante d'attention où plusieurs têtes de queries partagent une seule tête key-value, réduisant la taille du KV cache sans réduire significativement la qualité. Au lieu que chaque tête de query ait ses propres projections K et V (MHA standard), des groupes de têtes de queries partagent les projections K et V. Llama 2 70B, Mistral, Gemma et la plupart des LLM modernes utilisent GQA.

Pourquoi c'est important

GQA est la solution pratique au problème de mémoire du KV cache. L'attention multi-tête standard avec 64 têtes nécessite 64 jeux de tenseurs K et V par couche dans le cache. GQA avec 8 têtes KV réduit cela à 8 jeux — une réduction de mémoire de 8x. Cela se traduit directement par plus d'utilisateurs simultanés ou des contextes plus longs sur le même matériel.

En profondeur

Le spectre : Multi-Head Attention (MHA) a un nombre égal de têtes Q, K, V — qualité maximale, mémoire maximale. Multi-Query Attention (MQA) a beaucoup de têtes Q mais une seule tête K et une seule tête V — mémoire minimale, une certaine perte de qualité. GQA est le juste milieu : diviser les têtes Q en groupes, chaque groupe partageant une tête K et une tête V. Un modèle avec 32 têtes Q et 8 groupes KV a chaque tête KV servant 4 têtes Q.

Qualité vs. mémoire

La recherche montre que GQA avec 8 têtes KV égale la qualité de MHA pour la plupart des tâches tout en utilisant 4–8x moins de mémoire KV cache. La préservation de la qualité est quelque peu surprenante : elle suggère que beaucoup de têtes d'attention apprennent des patterns key-value similaires, donc les partager est efficient plutôt que limitant. Convertir un modèle MHA existant en GQA par "uptraining" (une courte phase de fine-tuning) est aussi efficace, évitant le besoin de réentraîner à partir de zéro.

Impact sur l'inférence

Les économies de mémoire KV cache de GQA se traduisent directement par : des fenêtres de contexte plus longues sur le même GPU, plus de requêtes simultanées (débit plus élevé), et un calcul d'attention plus rapide (moins de tenseurs K et V à lire). Pour un modèle 70B à un contexte de 128K, la différence entre MHA et GQA peut être de centaines de gigaoctets de KV cache — la différence entre avoir besoin de 8 GPU et de 4.

Concepts connexes

← Tous les termes
← GPU Gradient Checkpointing →