| Field | Value |
|---|---|
| id | wiki-2026-0508-grouped-query-attention-gqa |
| title | Grouped-Query Attention (GQA) |
| category | 10_Wiki/Topics |
| status | verified |
| canonical_id | self |
| aliases | |
| duplicate_of | none |
| source_trust_level | A |
| confidence_score | 0.96 |
| verification_status | applied |
| tags | |
| raw_sources | |
| last_reinforced | 2026-05-10 |
| github_commit | pending |
| tech_stack | |
Grouped-Query Attention (GQA)
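In one line
The middle ground between multi-head attention (MHA) and multi-query attention (MQA), introduced by Ainslie et al. 2023 (Google). There are N query heads but only G key/value heads (G < N), so each K/V head is shared by a group of N/G query heads. The KV cache shrinks accordingly while quality stays close to MHA. Standard in Llama 2 70B, Llama 3, Mistral, and most other modern LLMs.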
Core ideas
The spectrum
- MHA: Q=N, K=N, V=N (예: 32/32/32).
- MQA: Q=N, K=1, V=1 (예: 32/1/1).
- GQA: Q=N, K=G, V=G (예: 32/8/8).
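To make the grouping concrete, here is a small sketch that maps each query head to the K/V head it reads from. It assumes consecutive Q heads form a group, which is the layout `repeat_interleave(n_rep, dim=1)` produces in this note's GQA code; `kv_head_for` is a hypothetical helper.

```python
def kv_head_for(q_head: int, n_q: int, n_kv: int) -> int:
    """Index of the K/V head that query head `q_head` attends with."""
    if n_q % n_kv != 0:
        raise ValueError("n_q must be divisible by n_kv")
    group_size = n_q // n_kv  # number of Q heads sharing one K/V head
    return q_head // group_size

# First 8 of 32 query heads under MHA (32 KV heads), GQA (8) and MQA (1)
for n_kv in (32, 8, 1):
    print(n_kv, [kv_head_for(h, 32, n_kv) for h in range(8)])
```

MHA is the identity mapping, MQA sends every head to K/V head 0, and GQA sits in between with runs of N/G identical indices.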
Trade-offs
- MHA: best quality, largest KV cache.
- MQA: smallest cache, but quality drops.
- GQA: the sweet spot, with quality close to MHA at a fraction of the cache.
Inference impact
- KV cache (bytes) = batch × seq_len × n_layers × n_kv_heads × head_dim × 2 (K and V) × bytes per element.
- With GQA, going from N to G KV heads shrinks the cache by a factor of N/G.
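The same formula in symbols, with b = batch, s = seq_len, L = n_layers, G = n_kv_heads, d_h = head_dim, p = bytes per element. Dividing the GQA cache by the MHA cache (N KV heads) cancels every factor except the head counts:

```latex
\text{KV}_{\text{bytes}} = 2\, b\, s\, L\, G\, d_h\, p,
\qquad
\frac{\text{KV}_{\text{GQA}}}{\text{KV}_{\text{MHA}}}
  = \frac{2\, b\, s\, L\, G\, d_h\, p}{2\, b\, s\, L\, N\, d_h\, p}
  = \frac{G}{N}
```

For example, N = 64 and G = 8 gives a cache 1/8 the size of the MHA one.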
Applications
- Llama 2 70B: 64 Q heads, 8 KV heads.
- Llama 3: GQA across model sizes.
- Mistral, Mixtral: GQA.
- Gemma 2, Qwen2: GQA (the first Gemma release used MHA for 7B and MQA for 2B).
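A quick per-token KV cache comparison for a few of the models above. The (layers, Q heads, KV heads, head_dim) tuples are the published configs as I understand them; check each model's `config.json` before relying on the numbers. `CONFIGS` and `kv_bytes_per_token` are illustrative names.

```python
# (n_layers, n_q_heads, n_kv_heads, head_dim), fp16/bf16 cache assumed
CONFIGS = {
    "Llama-2-70B": (80, 64, 8, 128),
    "Llama-3-8B": (32, 32, 8, 128),
    "Mistral-7B": (32, 32, 8, 128),
}

def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    # One K and one V vector (factor 2) per layer and per KV head
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes

for name, (layers, n_q, n_kv, hd) in CONFIGS.items():
    gqa = kv_bytes_per_token(layers, n_kv, hd)
    mha = kv_bytes_per_token(layers, n_q, hd)  # same shape with one KV head per Q head
    print(f"{name}: group={n_q // n_kv}, {gqa // 1024} KiB/token (MHA: {mha // 1024} KiB)")
```

Multiply by total cached tokens (batch × sequence length) to get the full cache.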
💻 Patterns
MHA (baseline)
```python
import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # Q, K and V each get n_heads full heads (dim = n_heads * head_dim)
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # PyTorch >= 2.0
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```
GQA
```python
import torch
import torch.nn.functional as F

class GQA(torch.nn.Module):
    def __init__(self, dim, n_q_heads, n_kv_heads):
        super().__init__()
        assert n_q_heads % n_kv_heads == 0, "n_q_heads must be divisible by n_kv_heads"
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_q_heads // n_kv_heads  # Q heads per K/V head
        self.head_dim = dim // n_q_heads
        self.q_proj = torch.nn.Linear(dim, n_q_heads * self.head_dim)
        self.k_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)  # only G heads
        self.v_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.o_proj = torch.nn.Linear(n_q_heads * self.head_dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each K/V head n_rep times so Q heads g*n_rep .. (g+1)*n_rep - 1 share K/V head g
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```
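The GQA module above copies K/V with `repeat_interleave` before calling SDPA. The sketch below computes the same thing by broadcasting each K/V head over its group of query heads with plain matmuls. `grouped_attention` is a hypothetical helper; depending on the backend, matmul broadcasting may still copy internally, so read it as an illustration of the math rather than a memory optimization.

```python
import math
import torch

def grouped_attention(q, k, v, causal=True):
    """q: (B, N, T, D); k, v: (B, G, T, D) with N % G == 0."""
    B, N, T, D = q.shape
    G = k.shape[1]
    R = N // G                                   # query heads per K/V head
    qg = q.reshape(B, G, R, T, D)                # consecutive Q heads form a group
    kg, vg = k.unsqueeze(2), v.unsqueeze(2)      # (B, G, 1, T, D), broadcast over R
    scores = qg @ kg.transpose(-2, -1) / math.sqrt(D)  # (B, G, R, T, T)
    if causal:
        mask = torch.ones(T, T, dtype=torch.bool, device=q.device).triu(1)
        scores = scores.masked_fill(mask, float("-inf"))
    out = scores.softmax(dim=-1) @ vg            # (B, G, R, T, D)
    return out.reshape(B, N, T, D)
```

Recent PyTorch releases (2.5+) also accept `enable_gqa=True` in `F.scaled_dot_product_attention`, which handles K/V with fewer heads directly.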
MQA (extreme)
```python
import torch
import torch.nn.functional as F

class MQA(torch.nn.Module):
    """Multi-query attention: all Q heads share a single K/V head."""
    def __init__(self, dim, n_q_heads):
        super().__init__()
        self.n_q_heads = n_q_heads
        self.head_dim = dim // n_q_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, self.head_dim)  # single K head
        self.v_proj = torch.nn.Linear(dim, self.head_dim)  # single V head
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        # (B, 1, T, head_dim) expanded to every Q head as a view (no copy)
        k = self.k_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        v = self.v_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```
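MQA is the G = 1 end of the same family. Besides the cache, the number of K/V heads also sets the size of `k_proj` and `v_proj` in the modules above. A quick count, using dim = 4096 and 32 Q heads as illustrative values; `kv_proj_params` is a hypothetical helper that ignores biases.

```python
def kv_proj_params(dim: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Weight count of k_proj + v_proj (biases ignored)."""
    head_dim = dim // n_q_heads
    return 2 * dim * n_kv_heads * head_dim

for label, g in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    print(label, kv_proj_params(4096, 32, g))
```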
KV cache size (calculation)
```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    # Factor 2 = one K and one V tensor; dtype_bytes=2 for fp16/bf16
    return batch * seq_len * n_layers * n_kv_heads * head_dim * 2 * dtype_bytes

# Llama 2 70B shape: 80 layers, 64 Q heads, head_dim 128; batch=1, seq_len=2048
# Hypothetical MHA variant (64 KV heads):
print(kv_cache_bytes(1, 2048, 80, 64, 128))  # 5_368_709_120 bytes ≈ 5.4 GB
# Actual GQA config (8 KV heads):
print(kv_cache_bytes(1, 2048, 80, 8, 128))   # 671_088_640 bytes ≈ 671 MB (8x smaller)
```
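Turning the calculation around answers the practical serving question: how many tokens (summed over the batch) fit in the memory left for the cache. A minimal sketch, assuming a 40 GiB KV budget after weights and the Llama 2 70B shape above; `max_cached_tokens` is a hypothetical helper.

```python
def max_cached_tokens(mem_bytes, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    """Total tokens (over the whole batch) whose K/V fit in mem_bytes."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * dtype_bytes  # K and V
    return mem_bytes // per_token

budget = 40 * 2**30  # assumed 40 GiB free for the KV cache
print(max_cached_tokens(budget, 80, 64, 128))  # 16384 tokens with 64 KV heads (MHA)
print(max_cached_tokens(budget, 80, 8, 128))   # 131072 tokens with 8 KV heads (GQA)
```

The 8× ratio between the two results is exactly N/G.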
Convert MHA → GQA (mean-pool the heads in each group)
```python
import torch

@torch.no_grad()
def mha_to_gqa(k_proj, v_proj, n_q_heads, n_kv_heads):
    """Mean-pool each group of N/G consecutive K/V heads into one GQA head (weights only, biases ignored)."""
    n_rep = n_q_heads // n_kv_heads
    head_dim = k_proj.weight.size(0) // n_q_heads

    def pool(w):
        # (N*hd, dim) -> (G, N/G, hd, dim) -> mean over the group -> (G*hd, dim)
        return w.view(n_kv_heads, n_rep, head_dim, -1).mean(dim=1).reshape(n_kv_heads * head_dim, -1)

    return pool(k_proj.weight), pool(v_proj.weight)
```
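The conversion above returns raw weight tensors and drops biases. A sketch that wraps the pooled result in a fresh `nn.Linear` and also pools the bias (relevant for models whose K/V projections have one); `to_gqa_linear` is a hypothetical helper. As in the GQA paper, a converted checkpoint is expected to need some uptraining before its quality is comparable.

```python
import torch

@torch.no_grad()
def to_gqa_linear(proj: torch.nn.Linear, n_q_heads: int, n_kv_heads: int) -> torch.nn.Linear:
    """Mean-pool groups of consecutive heads of an MHA k_proj/v_proj into a smaller Linear."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be divisible by n_kv_heads")
    n_rep = n_q_heads // n_kv_heads
    head_dim = proj.out_features // n_q_heads
    w = proj.weight.view(n_kv_heads, n_rep, head_dim, proj.in_features).mean(dim=1)
    new = torch.nn.Linear(
        proj.in_features, n_kv_heads * head_dim, bias=proj.bias is not None,
        device=proj.weight.device, dtype=proj.weight.dtype,
    )
    new.weight.copy_(w.reshape(n_kv_heads * head_dim, proj.in_features))
    if proj.bias is not None:
        new.bias.copy_(proj.bias.view(n_kv_heads, n_rep, head_dim).mean(dim=1).reshape(-1))
    return new
```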
Llama-style GQA + RoPE
```python
import torch

class LlamaAttention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_q = config.num_attention_heads
        self.n_kv = config.num_key_value_heads  # GQA: fewer K/V heads than Q heads
        self.head_dim = config.hidden_size // self.n_q
        self.q_proj = torch.nn.Linear(config.hidden_size, self.n_q * self.head_dim, bias=False)
        self.k_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.v_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.o_proj = torch.nn.Linear(self.n_q * self.head_dim, config.hidden_size, bias=False)
        # RoPE is a placeholder for a rotary-embedding module; it is not defined in this note
        self.rotary = RoPE(self.head_dim, config.max_position_embeddings)

    def forward(self, x, kv_cache=None):
        # Structural sketch: project q/k/v, apply RoPE to q and k, append k/v (n_kv heads)
        # to kv_cache, repeat k/v n_q // n_kv times, then run causal attention.
        ...
```
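The `LlamaAttention` sketch above leaves RoPE and the forward pass undefined. Below is a self-contained, runnable version under stated assumptions: explicit constructor arguments instead of a config object, the half-split RoPE layout with base 10000 (Llama 2's value; Llama 3 uses a larger base), a `(k, v)` tuple cache that stores only the n_kv heads, and support only for a full prefill or one-token decode steps. `GQARoPEAttention` and `rotate_half` are illustrative names, not the Hugging Face classes.

```python
import torch
import torch.nn.functional as F

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) over the last dim: the half-split RoPE layout
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

class GQARoPEAttention(torch.nn.Module):
    def __init__(self, hidden_size, n_q, n_kv, max_pos=4096, rope_base=10000.0):
        super().__init__()
        assert n_q % n_kv == 0, "n_q must be divisible by n_kv"
        self.n_q, self.n_kv, self.head_dim = n_q, n_kv, hidden_size // n_q
        d = self.head_dim
        self.q_proj = torch.nn.Linear(hidden_size, n_q * d, bias=False)
        self.k_proj = torch.nn.Linear(hidden_size, n_kv * d, bias=False)
        self.v_proj = torch.nn.Linear(hidden_size, n_kv * d, bias=False)
        self.o_proj = torch.nn.Linear(n_q * d, hidden_size, bias=False)
        # Precomputed RoPE tables of shape (max_pos, head_dim)
        inv_freq = 1.0 / rope_base ** (torch.arange(0, d, 2).float() / d)
        angles = torch.outer(torch.arange(max_pos).float(), inv_freq).repeat(1, 2)
        self.register_buffer("cos", angles.cos(), persistent=False)
        self.register_buffer("sin", angles.sin(), persistent=False)

    def forward(self, x, kv_cache=None):
        """x: (B, T, hidden). Full prefill (kv_cache=None) or one-token decode steps."""
        B, T, _ = x.shape
        assert kv_cache is None or T == 1, "this sketch decodes one token at a time"
        start = 0 if kv_cache is None else kv_cache[0].size(2)
        cos, sin = self.cos[start:start + T], self.sin[start:start + T]  # (T, D), broadcasts
        q = self.q_proj(x).view(B, T, self.n_q, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv, self.head_dim).transpose(1, 2)
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin
        if kv_cache is not None:
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)
        new_cache = (k, v)  # only n_kv heads are cached: this is the GQA memory saving
        n_rep = self.n_q // self.n_kv
        k, v = k.repeat_interleave(n_rep, dim=1), v.repeat_interleave(n_rep, dim=1)
        # Causal mask for prefill; a single new token may attend to the whole cache
        out = F.scaled_dot_product_attention(q, k, v, is_causal=kv_cache is None)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, -1)), new_cache
```

The cache tensors returned here have n_kv heads per layer; that stored tensor, not the repeated one, is what the KV cache formula counts.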
vLLM (production GQA serving)
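```python
from vllm import LLM, SamplingParams

# vLLM reads num_key_value_heads from the model config, so its paged KV cache
# (PagedAttention) stores only the 8 KV heads; the attention kernel depends on the backend.
llm = LLM(model="meta-llama/Llama-3.1-70B-Instruct", tensor_parallel_size=4)  # assumption: 4 GPUs
outputs = llm.generate(["Explain grouped-query attention in one sentence."], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```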
Flash Attention with GQA
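```python
import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seqlen, n_heads, head_dim), fp16/bf16, on CUDA
q = torch.randn(1, 1024, 32, 128, dtype=torch.bfloat16, device="cuda")  # 32 Q heads
k = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")   # 8 KV heads
v = torch.randn_like(k)
out = flash_attn_func(q, k, v, causal=True)  # GQA natively (32 % 8 == 0), no repeat needed
```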
Benchmark KV cache savings
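```python
import time
import torch

def benchmark(model, batch_sizes, seq_lens, vocab_size=32000):
    # model: a Hugging Face causal LM on CUDA; vocab_size=32000 matches Llama 2 (adjust per model)
    for b in batch_sizes:
        for s in seq_lens:
            try:
                inputs = torch.randint(0, vocab_size, (b, s), device="cuda")
                torch.cuda.reset_peak_memory_stats()  # measure the peak of this run only
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                model.generate(inputs, max_new_tokens=128)
                torch.cuda.synchronize()
                dt = time.perf_counter() - t0
                print(f"b={b}, s={s}: {dt:.2f}s, peak={torch.cuda.max_memory_allocated() / 1e9:.2f}GB")
            except torch.cuda.OutOfMemoryError:
                print(f"b={b}, s={s}: OOM")
                torch.cuda.empty_cache()
```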
Decision criteria
| Situation | Approach |
|---|---|
| Long context | GQA (essential) |
| High batch size | GQA / MQA |
| Small, quality-critical model | MHA |
| Mobile / edge | MQA |
| Llama-style architecture | GQA (8 KV heads) |
| Modern default | GQA |
Default: GQA with 4-8 KV heads for a modern LLM. If GQA costs more than about 1% quality, use MHA; under extreme memory constraints, use MQA.
🔗 Graph
- Parent: Attention Mechanism · Transformer
- Variants: Multi-Head-Attention · Multi-Query-Attention
- Applications: Llama · Flash Attention · LLM_Optimization_and_Deployment_Strategies
- Adjacent: KV-Cache · PagedAttention · Foundation-Models
🤖 Use in LLMs
When: practically every modern LLM, especially for long contexts and high-batch serving. When not: very small models (< 1B parameters), where the KV cache savings matter less.
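❌ Anti-patterns
- MHA in a large production LLM: the KV cache runs out of memory (OOM) at long context or high batch.
- Wrong N/G ratio (too few KV heads for the model): quality drops.
- No GQA-aware kernel: K/V are repeated per query head inside attention, which is inefficient.
- Converting MHA → GQA without uptraining and re-evaluation: quality can fall off a cliff.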
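Which N/G ratios are available at all is a divisibility question: the grouped layout used throughout this note (and kernels such as FlashAttention) needs N divisible by G. A small sketch that lists the options and the resulting cache fraction; `valid_kv_head_counts` is a hypothetical helper, and which G actually preserves quality still has to be measured.

```python
def valid_kv_head_counts(n_q: int) -> list[int]:
    """KV head counts G that split n_q query heads into equal groups."""
    return [g for g in range(1, n_q + 1) if n_q % g == 0]

n_q = 32
for g in valid_kv_head_counts(n_q):
    # G = 1 is MQA, G = n_q is MHA; the KV cache is g / n_q of the MHA cache
    print(f"G={g:2d}  group size={n_q // g:2d}  cache={g / n_q:.3f} x MHA")
```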
🧪 Verification / duplicates
- Verified against Ainslie et al. 2023 (GQA), the Llama 2/3 papers, and Shazeer 2019 (MQA).
- Confidence: A.
🕓 Changelog
| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: MHA/MQA/GQA + KV cache calc / Llama / vLLM code |