大模型入门:从MHA到GQA,KV Cache显存优化指南

2026-06-02阅读 0热度 0
大模型

大模型入门:从 MHA 到 GQA,拆解 KV Cache 显存优化的核心逻辑

初次在本地部署大模型时,很多人直观认为显存主要由模型参数占据——这没错,一个 7B 模型即使采用 FP16 精度,参数显存也在十多个 GB 量级。

但进入实际推理阶段,另一个因素的增速远超预期。prompt 长度越长,KV Cache 膨胀越快;batch size 越大,KV Cache 越庞大;上下文窗口越宽,KV Cache 越难以控制;并发请求越高,KV Cache 的管理复杂度也急剧上升。

模型参数在加载后基本固定,而 KV Cache 在生成过程中随请求、序列长度和 batch 持续动态增长。这也是为何服务端推理框架投入大量精力优化 KV Cache:vLLM 的 PagedAttention、Hugging Face 的 DynamicCache/StaticCache/QuantizedCache,本质上都在解决同一个矛盾——既要快速读取历史 K/V,又不能撑爆显存。

GQA 恰好处于这个问题的中心位置。

一句话概括:GQA 通过削减 KV Head 的数量,从源头压缩 KV Cache 的体积。

1. 先回顾:KV Cache 到底缓存了什么

Decoder-only 大模型推理通常分为两个阶段:

阶段输入主要动作
Prefill完整 prompt一次性计算 prompt 每层的 K/V,并写入 cache
Decode当前新 token只计算新 token 的 Q/K/V,用新 Q 查询历史 K/V

Hugging Face 的缓存文档明确指出:自回归生成是逐 token 向后预测,KV Cache 保存了之前 token 在注意力层中的 K/V,后续 token 可直接复用,避免重复计算。

上一篇文章中,我们使用的 MHA 张量形状为:

q.shape == [batch, num_heads, seq_len, head_dim]
k.shape == [batch, num_heads, seq_len, head_dim]
v.shape == [batch, num_heads, seq_len, head_dim]

每一层需要缓存历史 token 的 kv

past_k.shape == [batch, num_heads, past_len, head_dim]
past_v.shape == [batch, num_heads, past_len, head_dim]

注意,这里是每一层独立缓存。一个 32 层模型就有 32 份缓存。因此 KV Cache 的显存估算公式为:

KV Cache bytes = batch_size * seq_len * num_layers * 2 * num_kv_heads * head_dim * bytes_per_element

公式中的 2 代表 K 和 V 两份。最容易忽视的是 num_kv_heads 这一项。

在 MHA 中:num_kv_heads = num_query_heads。而在 GQA 中:num_kv_heads < num_query_heads。这正是 GQA 节省显存的入口。

2. 用具体数字算清楚

假设一个简化配置:

batch_size = 1
seq_len = 8192
num_layers = 32
num_query_heads = 32
head_dim = 128
dtype = fp16  # 2 bytes

传统 MHA 下:num_kv_heads = 32,KV Cache 约为:

1 * 8192 * 32 * 2 * 32 * 128 * 2 bytes = 4 GiB

若改用 GQA,设 num_kv_heads = 8,KV Cache 约为:

1 * 8192 * 32 * 2 * 8 * 128 * 2 bytes = 1 GiB

同样的 Query Head 数量和上下文长度,仅将 KV Head 从 32 降至 8,缓存压缩为原来的四分之一。

如果是 MQA:num_kv_heads = 1,KV Cache 进一步降到:

128 MiB

这只是一个教学演示,实际框架还会受 allocator、block size、padding、并发调度、量化及 kernel 实现等因素影响。但作为面试和工程理解,这个公式足以抓住核心。

3. MHA、MQA、GQA 的区别

先用一张表快速区分:

结构Query HeadKV Head直觉
MHA多个和 Query 一样多每个 Q head 独享一组 K/V
MQA多个1 个所有 Q head 共享同一组 K/V
GQA多个介于 1 和 Query Head 之间一组 Q head 共享一组 K/V

设:num_query_heads = 32num_kv_heads = 8,则 group_size = num_query_heads // num_kv_heads = 4

GQA 的含义是:前 4 个 Q head(0123)共享一个 KV Head,接下来 4 个(4567)共享下一个,依此类推。它不像 MQA 将所有 Query Head 压到同一个 KV Head 上,也不像 MHA 为每个 Query Head 保留独立 K/V。

GQA 原论文的动机也在于此:MQA 能显著加速 decoder 推理,但可能牺牲质量;GQA 使用介于 1 和 Query Head 数之间的 KV Head 数量,在效果与推理效率之间寻求平衡。

4. 张量形状如何变化

MHA 的投影通常为:

q_proj: hidden_dim -> num_q_heads * head_dim
k_proj: hidden_dim -> num_q_heads * head_dim
v_proj: hidden_dim -> num_q_heads * head_dim

GQA 的投影变为:

q_proj: hidden_dim -> num_q_heads * head_dim
k_proj: hidden_dim -> num_kv_heads * head_dim
v_proj: hidden_dim -> num_kv_heads * head_dim

也就是说,Q 仍保持多头,K/V 减少。

假设:batch = 2seq_len = 5num_q_heads = 32num_kv_heads = 8head_dim = 128,则:

q.shape == [2, 32, 5, 128]
k.shape == [2, 8, 5, 128]
v.shape == [2, 8, 5, 128]

但在 attention 计算中,q @ k.transpose(-2, -1) 要求 head 维度对齐。一个教学做法是将 K/V 按组展开:

k_expanded.shape == [2, 32, 5, 128]
v_expanded.shape == [2, 32, 5, 128]

PyTorch 的 scaled_dot_product_attention(enable_gqa=True) 文档展示了类似逻辑:启用 GQA 时,会根据 Query Head 和 KV Head 的比例对 key/value 做 repeat_interleave。但实际高性能实现不一定物理复制 K/V,服务端推理更关注 cache 布局、访存模式和 kernel 实现。

5. 手写一个最小 GQA

以下代码保留核心逻辑,适合面试讲解:

  • Q Head 数可以大于 KV Head 数;
  • KV Head 必须能整除 Query Head;
  • K/V 先按较少 head 存储;
  • 计算 attention 前按组展开;
  • cache 里只缓存较少的 KV Head。
import math
import torch
from torch import nn

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: [B, H_kv, T, D]
    if n_rep == 1:
        return x
    batch, num_kv_heads, seq_len, head_dim = x.shape
    x = x[:, :, None, :, :]
    x = x.expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

class GroupedQueryAttention(nn.Module):
    def __init__(self,
                 hidden_dim: int,
                 num_q_heads: int,
                 num_kv_heads: int,
                 dropout: float = 0.0,
                 ):
        super().__init__()
        assert hidden_dim % num_q_heads == 0
        assert num_q_heads % num_kv_heads == 0
        self.hidden_dim = hidden_dim
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_dim // num_q_heads
        self.num_groups = num_q_heads // num_kv_heads

        self.q_proj = nn.Linear(hidden_dim, num_q_heads * self.head_dim)
        self.k_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(num_q_heads * self.head_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, num_heads, self.head_dim)
        return x.transpose(1, 2)  # [B, H, T, D]

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        batch, heads, seq_len, head_dim = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq_len, heads * head_dim)

    def forward(self,
                x: torch.Tensor,
                attn_mask: torch.Tensor | None = None,
                past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
                use_cache: bool = False,
                ):
        q = self._split_heads(self.q_proj(x), self.num_q_heads)
        k = self._split_heads(self.k_proj(x), self.num_kv_heads)
        v = self._split_heads(self.v_proj(x), self.num_kv_heads)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_key_value = (k, v) if use_cache else None

        k_for_attn = repeat_kv(k, self.num_groups)
        v_for_attn = repeat_kv(v, self.num_groups)

        scores = q @ k_for_attn.transpose(-2, -1)
        scores = scores / math.sqrt(self.head_dim)
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        out = weights @ v_for_attn
        out = self._merge_heads(out)
        out = self.o_proj(out)
        return out, weights, present_key_value

测试形状:

x = torch.randn(2, 5, 4096)
gqa = GroupedQueryAttention(
    hidden_dim=4096,
    num_q_heads=32,
    num_kv_heads=8,
)
out, weights, cache = gqa(x, use_cache=True)
print(out.shape)       # [2, 5, 4096]
print(weights.shape)   # [2, 32, 5, 5]
print(cache[0].shape)  # [2, 8, 5, 128]
print(cache[1].shape)  # [2, 8, 5, 128]

关键在最后两行:注意力权重仍是 32 个 Query Head —— weights.shape == [2, 32, 5, 5],但缓存里只保留 8 个 KV Head —— cache[0].shape == [2, 8, 5, 128]cache[1].shape == [2, 8, 5, 128]。这就是 GQA 在 KV Cache 上省显存的直接体现。

6. 用 PyTorch 接口怎么写

PyTorch 的 torch.nn.functional.scaled_dot_product_attention 已提供 enable_gqa 参数。

一个最小示例:

import torch
import torch.nn.functional as F

query = torch.randn(2, 32, 5, 128, device="cuda", dtype=torch.float16)
key = torch.randn(2, 8, 5, 128, device="cuda", dtype=torch.float16)
value = torch.randn(2, 8, 5, 128, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True,
    enable_gqa=True,
)
print(out.shape)  # [2, 32, 5, 128]

官方文档中两个关键约束:

number_of_heads_query % number_of_heads_key_value == 0
number_of_heads_key == number_of_heads_value

即:

  • Query Head 数必须能被 KV Head 数整除;
  • Key Head 数和 Value Head 数必须相同;
  • enable_gqa 目前仍是实验特性,后端支持和张量类型有限制。

另有一个易踩坑的点:PyTorch 该函数中布尔 attn_mask 的语义与其他 MHA 接口的 padding mask 相反。scaled_dot_product_attentionTrue 表示参与 attention,迁移代码时需仔细检查。

7. 为什么 GQA 主要影响推理

如果只做一次完整 forward 且不使用 KV Cache,GQA 对峰值显存的影响不如 KV Cache 场景明显。

真正的收益集中在自回归 decode:

每一步都要读历史 K/V
历史越长,读取量越大
并发越高,cache 越胀
KV Head 越少,cache 越小

Hugging Face 的优化文档也指出,减少 KV 向量数量只有在使用 KV Cache 的自回归解码场景中才特别有意义,因为 decode 阶段会反复读取历史 K/V,内存带宽极易成为瓶颈。

总结如下:

场景GQA 价值
训练全序列并行不是主要优化目标
Prefill可减少写入 cache 的 K/V 体积
Decode最关键,减少每步读取的历史 K/V
长上下文服务价值更突出
高并发服务价值更突出

因此讲 GQA 时,不能只罗列 attention 公式,必须将其放回推理服务的 KV Cache 场景中考察。

8. 与 vLLM、PagedAttention 的关系

GQA 解决的是:单个 token 的 K/V 体积更小。PagedAttention 解决的是:大量 token 的 K/V 如何高效组织和管理。两者不属于同一层优化,但共同影响推理效率。

vLLM 的 PagedAttention 文档提到,key/value cache 会被拆成 block,每个 block 存储固定数量 token 的 cache。这样做的目的是用更适合服务端调度的方式管理 KV Cache,而非将每个请求视为一大段连续显存。

可以放在同一张图中理解:

GQA:减少每个 token 的 KV 体积
PagedAttention:管理许多 token 的 KV 存放方式
Quantized Cache:降低每个元素的字节数
Offloaded Cache:将部分 cache 放到 CPU

若只看单次模型结构,GQA 像是 attention 结构变化;若从推理系统看,GQA 是 KV Cache 成本控制的一环。

9. 常见坑

坑 1:只改 num_kv_heads,忘了改投影层输出维度

GQA 中 Q/K/V 的 projection 输出维度不同:

q_proj -> num_q_heads * head_dim
k_proj -> num_kv_heads * head_dim
v_proj -> num_kv_heads * head_dim

如果仍将 K/V 投影到 num_q_heads * head_dim,cache 就省不下来。

坑 2:num_q_heads 不能整除 num_kv_heads

GQA 按组共享 K/V,因此通常要求 num_q_heads % num_kv_heads == 0,否则每组 Query Head 无法均匀映射到 KV Head。

坑 3:把 repeat 后的 K/V 当成 cache 存

教学代码为便于理解,会在 attention 前做 repeat_kv。但 cache 应该保留较少的 KV Head:cache_k.shape == [B, H_kv, T, D]。若把展开后的 K/V 存进去:cache_k.shape == [B, H_q, T, D],显存又回到 MHA 水平。

坑 4:只算 cache 容量,不看内存带宽

KV Cache 不只占显存。Decode 每一步都要读取历史 K/V,所以内存带宽也会成为瓶颈。GQA 的价值不仅是少存,也包括少读。

坑 5:把 GQA 当成无损替换

GQA 是效果与效率的折中。原论文结论是 GQA 相比 MQA 更能保留 MHA 的质量,同时接近 MQA 的速度收益。但具体效果仍取决于模型、训练方式、上采样策略和任务。工程上不应将结构变化视为“免费优化”——它通常在模型设计或训练阶段就已确定。

10. 面试怎么讲

如果面试官问:“GQA 和 MHA 有什么区别?”

可以回答:GQA 的核心差异在于 Query Head 数量多于 Key/Value Head 数量,多个 Query Head 共享同一组 K/V。而 MHA 中每个 Query Head 都有独立的 K/V。

如果继续问:“为什么能省显存?”

可以接:因为 KV Cache 的大小与 num_kv_heads 直接成正比。在相同的 Query Head 数量和序列长度下,GQA 只需缓存更少的 K/V Head,因此显存占用更小。

如果问:“GQA、MQA 怎么区分?”

可以答:MQA 是所有 Query Head 共享一个 KV Head,极端省显存但可能损失效果;GQA 是折中方案,将 Query Head 分成若干组,每组共享一个 KV Head。

如果问:“代码里最容易错在哪里?”

可以答:最容易错的是投影层的输出维度改错,以及 cache 中意外存储展开后的 K/V。核心约束是 num_q_heads % num_kv_heads == 0

11. 一张速记表

问题关键回答
GQA 改了什么?Query Head 多,KV Head 少
为什么能省显存?KV Cache 大小和 num_kv_heads 成正比
MHA 的 KV Head 数?通常等于 Query Head 数
MQA 的 KV Head 数?1 个
GQA 的 KV Head 数?介于 1 和 Query Head 数之间
代码核心约束?num_q_heads % num_kv_heads == 0
cache 里存什么?未展开的 K/V,形状是 [B, H_kv, T, D]
attention 前做什么?把 K/V 按组映射到 Query Head
最适合讲的场景?长上下文、自回归 decode、高并发推理
PyTorch 接口?scaled_dot_product_attention(..., enable_gqa=True)

总结

GQA 可以用三句话概括:

  1. MHA 中每个 Query Head 通常拥有自己的 K/V,KV Cache 随 Query Head 数增长。
  2. GQA 让一组 Query Head 共享较少的 K/V Head,KV Cache 随 KV Head 数增长。
  3. 其主要价值出现在自回归推理,尤其是长上下文和高并发服务中。

所以,学习 GQA 不能只记住一个缩写。真正需要记住的是这条线索:

MHA 张量形状 -> KV Cache 显存公式 -> KV Head 数量 -> Decode 访存压力 -> GQA

这条线理清了,GQA、MQA、KV Cache、长上下文推理优化就能形成体系。

参考资料

  • Joshua Ainslie et al.:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
    arxiv.org/abs/2305.13…
  • PyTorch:torch.nn.functional.scaled_dot_product_attention
    docs.pytorch.org/docs/main/g…
  • Hugging Face Transformers:Caching
    huggingface.co/docs/transf…
  • Hugging Face Transformers:KV cache strategies
    huggingface.co/docs/transf…
  • Hugging Face Transformers:Optimizing LLMs for Speed and Memory
    huggingface.co/docs/transf…
  • vLLM:Paged Attention
    docs.vllm.ai/en/latest/d…
免责声明

本网站新闻资讯均来自公开渠道,力求准确但不保证绝对无误,内容观点仅代表作者本人,与本站无关。若涉及侵权,请联系我们处理。本站保留对声明的修改权,最终解释权归本站所有。

相关阅读

更多
欢迎回来 登录或注册后,可保存提示词和历史记录
登录后可同步收藏、历史记录和常用模板
注册即表示同意服务条款与隐私政策