---
id: wiki-2026-0508-grouped-query-attention-gqa
title: Grouped-Query Attention (GQA)
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [GQA, grouped-query attention, MQA, multi-query attention, KV cache reduction, Llama]
duplicate_of: none
source_trust_level: A
confidence_score: 0.96
verification_status: applied
tags: [transformer, attention, gqa, mqa, kv-cache, llama, inference-optimization]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / vLLM
---

# Grouped-Query Attention (GQA)

## One-liner

> **"The middle ground between multi-head attention and multi-query attention."** Ainslie et al. 2023 (Google). N query heads share G key/value heads (1 < G < N). The KV cache shrinks while quality stays close to MHA. GQA is used by Llama 2 70B, Llama 3, Mistral, and most other recent LLMs.

## Core concepts

### The spectrum

- **MHA**: Q=N, K=N, V=N (e.g., 32/32/32).
- **MQA**: Q=N, K=1, V=1 (e.g., 32/1/1).
- **GQA**: Q=N, K=G, V=G (e.g., 32/8/8). G=N recovers MHA; G=1 recovers MQA.

### Trade-offs

- **MHA**: best quality, largest KV cache.
- **MQA**: smallest KV cache, but noticeably lower quality.
- **GQA**: the sweet spot: quality close to MHA with inference speed close to MQA.

### Inference impact

- **KV cache (bytes)** = batch × seq_len × n_layers × n_kv_heads × head_dim × 2 (K and V) × bytes per element.
- Going from N to G KV heads shrinks the cache by a factor of N/G (e.g., 64 → 8 KV heads is 8× smaller).

### Adoption

- **Llama 2 70B**: 64 Q heads, 8 KV heads (the smaller Llama 2 models use MHA).
- **Llama 3**: GQA is standard across model sizes.
- **Mistral**, **Mixtral**: GQA (32 Q heads, 8 KV heads).
- **Gemma**, **Qwen**: GQA in later generations (e.g., Gemma 2, Qwen2).
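To make the spectrum and the N/G cache reduction concrete, here is a small sketch that maps each query head to the KV head it shares and compares per-token cache size for MHA, GQA, and MQA. The helper names (`kv_head_for_query`, `kv_elems_per_token`) are hypothetical, the 32-layer / 32-head / head_dim 128 shape is only illustrative, and it assumes consecutive grouping (query heads 0..N/G-1 share KV head 0, and so on), which is the layout `repeat_interleave` produces.

```python
# Hypothetical helpers illustrating the MHA / GQA / MQA spectrum (not a library API).


def kv_head_for_query(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Index of the KV head that query head `q_head` reads from.

    Assumes consecutive grouping: each block of n_q_heads // n_kv_heads
    query heads shares one KV head (same layout as repeat_interleave).
    """
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be divisible by n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


def kv_elems_per_token(n_layers: int, n_kv_heads: int, head_dim: int) -> int:
    """Number of K and V elements cached per token across all layers."""
    return n_layers * n_kv_heads * head_dim * 2


if __name__ == "__main__":
    n_q, n_layers, head_dim = 32, 32, 128  # illustrative shape only
    mha_elems = kv_elems_per_token(n_layers, n_q, head_dim)
    for name, g in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
        mapping = [kv_head_for_query(h, n_q, g) for h in range(n_q)]
        ratio = mha_elems // kv_elems_per_token(n_layers, g, head_dim)
        print(f"{name}: G={g}, kv head for q heads 0-7 = {mapping[:8]}, MHA cache / {name} cache = {ratio}")
```

With 32 query heads, G = 8 sends query heads 0-3 to KV head 0 and 4-7 to KV head 1, with a cache ratio of 4; G = 1 sends every query head to KV head 0, a 32× reduction.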
## 💻 Patterns

### MHA (baseline)

```python
import torch
import torch.nn.functional as F


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```

### GQA

```python
class GQA(torch.nn.Module):
    def __init__(self, dim, n_q_heads, n_kv_heads):
        super().__init__()
        assert n_q_heads % n_kv_heads == 0, "n_q_heads must be divisible by n_kv_heads"
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_q_heads // n_kv_heads  # Q heads per shared KV head
        self.head_dim = dim // n_q_heads
        self.q_proj = torch.nn.Linear(dim, n_q_heads * self.head_dim)
        self.k_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.v_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head n_rep times so K/V match the Q head count
        # (PyTorch 2.5+ can skip this by passing enable_gqa=True to scaled_dot_product_attention)
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```

### MQA (extreme)

```python
class MQA(torch.nn.Module):
    """Single KV head shared by all query heads."""

    def __init__(self, dim, n_q_heads):
        super().__init__()
        self.n_q_heads = n_q_heads
        self.head_dim = dim // n_q_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, self.head_dim)  # single KV head
        self.v_proj = torch.nn.Linear(dim, self.head_dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        # expand() broadcasts the one KV head across all Q heads without copying memory
        k = self.k_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        v = self.v_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```

### KV cache size (calculation)

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    # factor 2 = one K and one V tensor; dtype_bytes=2 for fp16/bf16
    return batch * seq_len * n_layers * n_kv_heads * head_dim * 2 * dtype_bytes


# Llama 2 70B shape with hypothetical MHA: 80 layers, 64 KV heads, head_dim 128
# batch=1, seq=2048, fp16
# cache = 1 * 2048 * 80 * 64 * 128 * 2 * 2 ≈ 5.4 GB
# Llama 2 70B as released (GQA): 8 KV heads
# cache = 1 * 2048 * 80 * 8 * 128 * 2 * 2 ≈ 671 MB (8x smaller)
```

### Convert MHA → GQA (mean pooling)

```python
@torch.no_grad()
def mha_to_gqa(k_proj, v_proj, n_q_heads, n_kv_heads):
    """Mean-pool each group of n_q_heads // n_kv_heads consecutive K/V heads into one GQA head (biases ignored)."""
    n_rep = n_q_heads // n_kv_heads
    head_dim = k_proj.weight.size(0) // n_q_heads
    # (n_q*head_dim, dim) -> (n_kv, n_rep, head_dim, dim) -> mean over n_rep -> (n_kv*head_dim, dim)
    new_k_weight = k_proj.weight.view(n_q_heads, head_dim, -1).reshape(n_kv_heads, n_rep, head_dim, -1).mean(dim=1).reshape(n_kv_heads * head_dim, -1)
    new_v_weight = v_proj.weight.view(n_q_heads, head_dim, -1).reshape(n_kv_heads, n_rep, head_dim, -1).mean(dim=1).reshape(n_kv_heads * head_dim, -1)
    return new_k_weight, new_v_weight
```

### Llama-style GQA + RoPE

```python
class LlamaAttention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_q = config.num_attention_heads
        self.n_kv = config.num_key_value_heads  # n_kv < n_q means GQA
        self.head_dim = config.hidden_size // self.n_q
        self.q_proj = torch.nn.Linear(config.hidden_size, self.n_q * self.head_dim, bias=False)
        self.k_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.v_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.o_proj = torch.nn.Linear(self.n_q * self.head_dim, config.hidden_size, bias=False)
        # RoPE: rotary-embedding module assumed to be defined elsewhere (not shown)
        self.rotary = RoPE(self.head_dim, config.max_position_embeddings)

    def forward(self, x, kv_cache=None):
        # project q, k, v -> apply RoPE -> append k, v to the cache -> attention (body elided)
        ...
```

### vLLM (production GQA serving)

```python
from vllm import LLM

# GQA comes from the model config; vLLM serves it with PagedAttention and fused attention kernels.
# A 70B model usually also needs tensor_parallel_size > 1 (multiple GPUs).
llm = LLM(model='meta-llama/Llama-3.1-70B-Instruct')
```

### Flash Attention with GQA

```python
from flash_attn import flash_attn_func

# q: (batch, seqlen, n_q_heads, head_dim); k, v: (batch, seqlen, n_kv_heads, head_dim); fp16/bf16 on CUDA
# GQA/MQA are supported natively when n_q_heads is divisible by n_kv_heads
out = flash_attn_func(q, k, v, causal=True)
```

### Benchmark KV cache savings

```python
import time

import torch


def benchmark(model, batch_sizes, seq_lens):
    for b in batch_sizes:
        for s in seq_lens:
            try:
                torch.cuda.reset_peak_memory_stats()  # measure peak per run, not cumulatively
                inputs = torch.randint(0, 32000, (b, s)).cuda()  # 32000 = Llama 2 vocab size; adjust per model
                t0 = time.perf_counter()
                model.generate(inputs, max_new_tokens=128)
                torch.cuda.synchronize()
                print(f'b={b}, s={s}: {time.perf_counter()-t0:.2f}s, peak={torch.cuda.max_memory_allocated()/1e9:.2f}GB')
            except torch.cuda.OutOfMemoryError:
                print(f'b={b}, s={s}: OOM')
            torch.cuda.empty_cache()
```

## Decision criteria

| Situation | Approach |
|---|---|
| Long context | GQA (essential) |
| High batch | GQA / MQA |
| Quality-critical small model | MHA |
| Mobile / edge | MQA |
| Llama-style | GQA (8 KV heads) |
| Modern default | GQA |

**Default**: GQA with 4-8 KV heads for modern LLMs. If GQA costs more than ~1% quality on your evals, use MHA. Under extreme memory constraints, use MQA.

## 🔗 Graph

- Parent: [[Attention-Mechanism]] · [[Transformer]]
- Variants: [[Multi-Head-Attention]] · [[Multi-Query-Attention]]
- Applications: [[Llama]] · [[Flash Attention]] · [[LLM_Optimization_and_Deployment_Strategies|vLLM]]
- Adjacent: [[KV-Cache]] · [[Paged-Attention]] · [[Foundation-Models]]

## 🤖 LLM usage

**When**: virtually every modern LLM; long context; high-batch serving.

**When not**: very small models (< 1B), where the KV cache is rarely the bottleneck.

## ❌ Anti-patterns

- **MHA in a large production LLM**: the KV cache runs out of memory at long context or high batch.
- **Wrong N/G ratio**: too few KV heads costs quality; N must also be divisible by G.
- **No GQA-aware kernel**: materializing repeated K/V heads wastes memory and bandwidth.
- **Converting without a retraining check**: mean-pooling an MHA checkpoint into GQA without uptraining and evaluation can cause a sharp quality drop.

## 🧪 Verification / Duplicates

- Verified (Ainslie GQA 2023, Llama 2/3 papers, Shazeer MQA 2019).
- Trust level: A.

## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: MHA/MQA/GQA + KV cache calc / Llama / vLLM code |