Zubnet AI学习Wiki › 多头注意力
基础

多头注意力

别名:MHA
并行运行多个注意力操作,每个操作都有自己学习到的查询、键和值的投影。多头注意力不是用一个注意力函数处理完整的模型维度,而是将维度分成多个"头"(例如,4096维模型使用32个头,每个128维)。每个头可以同时关注不同类型的关系。

为什么重要

多头注意力是Transformer如此强大的原因。一个头可能关注句法关系(主谓),另一个关注位置模式(相邻词),再一个关注语义相似性。这种并行的专业化使模型能够同时捕获多种类型的依赖关系,这是单个注意力头无法有效做到的。

深度解析

机制:对于每个头i,模型学习独立的投影矩阵W_Q^i、W_K^i、W_V^i,将输入投影到低维空间(head_dim = model_dim / num_heads)。每个头独立计算注意力:softmax(Q_i · K_i^T / √d) · V_i。所有头的输出被拼接并通过最终的线性层W_O投影回完整的模型维度。

头的专业化

研究表明不同的头会学习不同的功能。有些头关注前一个token(位置性的)。有些关注句法相关的token(主语关注其动词)。有些实现"归纳"(模式补全)。有些广泛关注(收集全局上下文)。并非所有头同等重要——修剪20–40%的头通常对性能影响很小,这表明存在显著的冗余。

GQA和MQA

多查询注意力(MQA)使用一个跨所有查询头共享的键值头,将KV缓存大小减少了头的数量倍。分组查询注意力(GQA)是一种折中方案:多组查询头共享一个键值头(例如,32个查询头配8个KV头)。GQA在大幅减少KV缓存内存的同时保留了MHA的大部分质量。Llama 2 70B、Mistral和大多数现代LLM使用GQA。

相关概念

← 所有术语
← 基础模型 多智能体系统 →