---
id: wiki-2026-0508-grouped-query-attention-gqa
title: Grouped-Query Attention (GQA)
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [GQA, grouped-query attention, MQA, multi-query attention, KV cache reduction, Llama]
duplicate_of: none
source_trust_level: A
confidence_score: 0.96
verification_status: applied
tags: [transformer, attention, gqa, mqa, kv-cache, llama, inference-optimization]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / vLLM
---
# Grouped-Query Attention (GQA)
## One-line summary
> **"The midpoint between multi-head attention and multi-query attention."** Ainslie et al., 2023 (Google). Q heads = N, K/V heads = G (1 < G < N). The KV cache shrinks by a factor of N/G while quality stays close to MHA. Used by Llama 2 70B, Llama 3, and Mistral, and the de facto standard in most modern LLMs.
## Key points
### Spectrum
- **MHA**: Q=N, K=N, V=N heads (e.g. 32/32/32).
- **MQA**: Q=N, K=1, V=1 heads (e.g. 32/1/1).
- **GQA**: Q=N, K=G, V=G heads (e.g. 32/8/8).
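To make the spectrum concrete, here is a minimal sketch of which KV head each query head reads from. The helper names `kv_head_for` and `mapping` are hypothetical, and the contiguous grouping (the first N/G query heads share KV head 0, and so on) is an assumption that matches the `repeat_interleave` layout used in the GQA pattern on this page.

```python
def kv_head_for(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Return the index of the KV head that query head `q_head` attends with."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be divisible by n_kv_heads")
    group_size = n_q_heads // n_kv_heads  # query heads per KV head
    return q_head // group_size


def mapping(n_q_heads: int, n_kv_heads: int) -> list[int]:
    """KV head index for every query head, in order."""
    return [kv_head_for(h, n_q_heads, n_kv_heads) for h in range(n_q_heads)]


print(mapping(8, 8))  # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
print(mapping(8, 2))  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
print(mapping(8, 1))  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]
```

MHA and MQA fall out as the two end points of the same mapping (G = N and G = 1).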
### Trade-off
- **MHA**: best quality, largest KV cache.
- **MQA**: smallest cache, lower quality.
- **GQA**: the sweet spot between the two.
### Inference impact
- **KV cache size** (in elements) = batch × seq_len × n_layers × n_kv_heads × head_dim × 2 (K and V). Multiply by bytes per element to get bytes.
- **GQA**: reducing the KV heads from N to G shrinks the cache by a factor of N/G.
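The formula can also be read in reverse: given a fixed memory budget, how many tokens fit in the cache? The sketch below is illustrative only. The function names and the 16 GiB budget are assumptions; the layer count, head counts, and head dimension are the Llama 2 70B-sized figures used elsewhere on this page.

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2) -> int:
    """Bytes of K and V stored for one token across all layers."""
    return n_layers * n_kv_heads * head_dim * 2 * dtype_bytes


def max_cached_tokens(budget_bytes: int, n_layers: int, n_kv_heads: int,
                      head_dim: int, dtype_bytes: int = 2) -> int:
    """Total tokens (batch x seq_len) whose KV cache fits in `budget_bytes`."""
    per_token = kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes)
    return budget_bytes // per_token


budget = 16 * 1024**3  # hypothetical 16 GiB reserved for the KV cache
for name, n_kv in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    tokens = max_cached_tokens(budget, n_layers=80, n_kv_heads=n_kv, head_dim=128)
    print(f"{name}: {tokens} tokens")
```

With these numbers the budget holds roughly 6.5k tokens under MHA, 52k under GQA, and 419k under MQA, which is the N/G factor seen from the capacity side.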
### Applications
- **Llama 2 70B**: 64 Q heads, 8 KV heads.
- **Llama 3**: GQA is standard across model sizes.
- **Mistral**, **Mixtral**: GQA.
- **Gemma**, **Qwen**: GQA in recent releases.
## 💻 Patterns
### MHA (baseline)
```python
import torch
import torch.nn.functional as F


class MultiHeadAttention(torch.nn.Module):
    """Baseline MHA: every query head has its own K and V head."""

    def __init__(self, dim, n_heads):
        super().__init__()
        assert dim % n_heads == 0, "dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```
### GQA
```python
import torch
import torch.nn.functional as F


class GQA(torch.nn.Module):
    """Grouped-query attention: n_q_heads query heads share n_kv_heads K/V heads."""

    def __init__(self, dim, n_q_heads, n_kv_heads):
        super().__init__()
        assert dim % n_q_heads == 0, "dim must be divisible by n_q_heads"
        assert n_q_heads % n_kv_heads == 0, "n_q_heads must be divisible by n_kv_heads"
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_q_heads // n_kv_heads  # query heads per KV head
        self.head_dim = dim // n_q_heads
        self.q_proj = torch.nn.Linear(dim, n_q_heads * self.head_dim)
        self.k_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.v_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head n_rep times so K/V match the Q head count.
        # (PyTorch >= 2.5 can skip this copy via enable_gqa=True in SDPA.)
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```
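Besides the cache, the narrower `k_proj` and `v_proj` in the GQA module also hold fewer weights. The sketch below counts them for a Llama 2 70B-sized layer (hidden size 8192, 64 query heads). The helper name `kv_proj_params` is hypothetical, and biases are ignored for simplicity.

```python
def kv_proj_params(dim: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Weight count of the K and V projections together (bias ignored)."""
    head_dim = dim // n_q_heads
    # Each projection maps dim -> n_kv_heads * head_dim; the factor 2 covers K and V.
    return 2 * dim * n_kv_heads * head_dim


dim, n_q = 8192, 64  # Llama 2 70B-sized layer
mha = kv_proj_params(dim, n_q, n_kv_heads=64)
gqa = kv_proj_params(dim, n_q, n_kv_heads=8)
print(f"MHA K/V weights per layer: {mha}")
print(f"GQA K/V weights per layer: {gqa} ({mha // gqa}x fewer)")
```

The Q and output projections are unchanged, so the overall parameter saving is much smaller than the cache saving.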
### MQA (extreme)
```python
import torch
import torch.nn.functional as F


class MQA(torch.nn.Module):
    """Multi-query attention: a single K/V head shared by all query heads."""

    def __init__(self, dim, n_q_heads):
        super().__init__()
        assert dim % n_q_heads == 0, "dim must be divisible by n_q_heads"
        self.n_q_heads = n_q_heads
        self.head_dim = dim // n_q_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, self.head_dim)  # single head
        self.v_proj = torch.nn.Linear(dim, self.head_dim)  # single head
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        # expand() broadcasts the single KV head across all query heads without copying.
        k = self.k_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        v = self.v_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```
### KV cache size (calculation)
```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    """KV cache size in bytes; the factor 2 counts K and V."""
    return batch * seq_len * n_layers * n_kv_heads * head_dim * 2 * dtype_bytes


# Llama 2 70B dimensions: 80 layers, 64 Q heads, head_dim 128; batch=1, seq=2048, fp16.
# Hypothetical MHA variant (64 KV heads):
#   1 * 2048 * 80 * 64 * 128 * 2 * 2 = 5,368,709,120 bytes ≈ 5.4 GB
# Actual GQA configuration (8 KV heads):
#   1 * 2048 * 80 * 8 * 128 * 2 * 2 = 671,088,640 bytes ≈ 670 MB (8x smaller)
```
### Convert MHA → GQA (mean pooling)
```python
def mha_to_gqa(k_proj, v_proj, n_q_heads, n_kv_heads):
    """Mean-pool each group of N/G MHA heads into a single GQA KV head.

    Returns new K and V weights of shape (n_kv_heads * head_dim, dim).
    Biases are not handled. The GQA paper follows pooling with uptraining.
    """
    n_rep = n_q_heads // n_kv_heads
    head_dim = k_proj.weight.size(0) // n_q_heads

    def pool(weight):
        # (N * head_dim, dim) -> (G, N/G, head_dim, dim), then average each group.
        grouped = weight.view(n_kv_heads, n_rep, head_dim, -1)
        return grouped.mean(dim=1).reshape(n_kv_heads * head_dim, -1)

    return pool(k_proj.weight), pool(v_proj.weight)
```
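The grouping that the conversion relies on is easier to see on toy data. The following pure-Python sketch is only an illustration of the pooling step, not the tensor code: `mean_pool_heads` is a hypothetical helper, and each "head" is flattened to a short list of numbers.

```python
def mean_pool_heads(heads: list[list[float]], n_kv_heads: int) -> list[list[float]]:
    """Average each consecutive group of len(heads) / n_kv_heads head vectors."""
    if len(heads) % n_kv_heads != 0:
        raise ValueError("head count must be divisible by n_kv_heads")
    n_rep = len(heads) // n_kv_heads  # original heads merged into one KV head
    pooled = []
    for g in range(n_kv_heads):
        group = heads[g * n_rep:(g + 1) * n_rep]
        # zip(*group) walks the group element by element across its heads.
        pooled.append([sum(col) / n_rep for col in zip(*group)])
    return pooled


# Four toy heads pooled into two KV heads: heads 0-1 merge, heads 2-3 merge.
heads = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
print(mean_pool_heads(heads, n_kv_heads=2))  # [[2.0, 3.0], [20.0, 30.0]]
```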
### Llama-style GQA + RoPE
```python
import torch


class LlamaAttention(torch.nn.Module):
    """Skeleton only: `RoPE` is a placeholder module that is not defined here."""

    def __init__(self, config):
        super().__init__()
        self.n_q = config.num_attention_heads
        self.n_kv = config.num_key_value_heads  # fewer than n_q under GQA
        self.head_dim = config.hidden_size // self.n_q
        self.q_proj = torch.nn.Linear(config.hidden_size, self.n_q * self.head_dim, bias=False)
        self.k_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.v_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.o_proj = torch.nn.Linear(self.n_q * self.head_dim, config.hidden_size, bias=False)
        self.rotary = RoPE(self.head_dim, config.max_position_embeddings)

    def forward(self, x, kv_cache=None):
        # Project q, k, v; apply RoPE to q and k; read and update the KV cache.
        ...
```
### vLLM (production GQA serving)
```python
from vllm import LLM
llm = LLM(model='meta-llama/Llama-3.1-70B-Instruct')
# Internally uses GQA + PagedAttention + FlashAttention kernels
```
### Flash Attention with GQA
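```python
import torch
from flash_attn import flash_attn_func
# flash-attn expects (batch, seq_len, n_heads, head_dim) fp16/bf16 tensors on CUDA.
q = torch.randn(1, 1024, 32, 128, device="cuda", dtype=torch.float16)
k, v = (torch.randn(1, 1024, 8, 128, device="cuda", dtype=torch.float16) for _ in range(2))
out = flash_attn_func(q, k, v, causal=True)  # GQA handled natively: 32 Q heads, 8 KV heads
```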
### Benchmark KV cache savings
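```python
import time

import torch


def benchmark(model, batch_sizes, seq_lens, vocab_size=32000):
    """Time generation and report peak GPU memory (assumes a Hugging Face-style `generate`)."""
    for b in batch_sizes:
        for s in seq_lens:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()  # otherwise the peak carries over between runs
            try:
                inputs = torch.randint(0, vocab_size, (b, s), device="cuda")
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                model.generate(inputs, max_new_tokens=128)
                torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
                elapsed = time.perf_counter() - t0
                peak_gb = torch.cuda.max_memory_allocated() / 1e9
                print(f"b={b}, s={s}: {elapsed:.2f}s, peak={peak_gb:.2f}GB")
            except torch.cuda.OutOfMemoryError:
                print(f"b={b}, s={s}: OOM")
```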
## Decision criteria
| Situation | Approach |
|---|---|
| Long context | GQA (essential) |
| High batch | GQA / MQA |
| Quality-critical, small model | MHA |
| Mobile / edge | MQA |
| Llama-style | GQA (8 KV heads) |
| Modern default | GQA |
**Default**: for a modern LLM, use GQA (4-8 KV heads). If GQA costs more than about 1% quality on your evaluations, fall back to MHA. Under extreme memory constraints, use MQA.
## 🔗 Graph
- Parents: [[Attention Mechanism]] · [[Transformer]]
- Variants: [[Multi-Head-Attention]] · [[Multi-Query-Attention]]
- Applications: [[Llama]] · [[Flash Attention]] · [[LLM_Optimization_and_Deployment_Strategies|vLLM]]
- Adjacent: [[KV-Cache]] · [[PagedAttention]] · [[Foundation-Models]]
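## 🤖 LLM usage
**When to use**: any modern LLM, long context, high-batch serving.
**When not to use**: very small models (< 1B parameters).
## ❌ Anti-patterns
- **MHA in a large production LLM**: the KV cache runs out of memory at long context or high batch.
- **Wrong N/G ratio**: too few KV heads for the model drops quality.
- **No GQA-aware kernel**: materializing repeated K/V copies wastes memory and bandwidth.
- **Converting without a retraining check**: a mean-pooled checkpoint can fall off a quality cliff unless it is uptrained and re-evaluated.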
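Some of these mistakes are purely structural and can be caught before any training or serving. The check below is a minimal sketch with a hypothetical name, `check_gqa_config`; it validates only head-count divisibility and says nothing about whether a given N/G ratio preserves quality, which still needs evaluation.

```python
def check_gqa_config(hidden_size: int, n_q_heads: int, n_kv_heads: int) -> list[str]:
    """Return a list of problems found; an empty list means the shapes are consistent."""
    problems = []
    if n_q_heads <= 0 or n_kv_heads <= 0:
        problems.append("head counts must be positive")
        return problems
    if hidden_size % n_q_heads != 0:
        problems.append("hidden_size is not divisible by n_q_heads")
    if n_kv_heads > n_q_heads:
        problems.append("n_kv_heads exceeds n_q_heads")
    elif n_q_heads % n_kv_heads != 0:
        problems.append("n_q_heads is not divisible by n_kv_heads")
    return problems


print(check_gqa_config(8192, 64, 8))   # []
print(check_gqa_config(4096, 32, 12))  # ['n_q_heads is not divisible by n_kv_heads']
```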
## 🧪 Verification / duplicates
- Verified against Ainslie et al. 2023 (GQA), the Llama 2/3 papers, and Shazeer 2019 (MQA).
- Trust level A.
## 🕓 Changelog
| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: MHA/MQA/GQA + KV cache calc / Llama / vLLM code |