---
id: wiki-2026-0508-grouped-query-attention-gqa
title: Grouped-Query Attention (GQA)
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [GQA, grouped-query attention, MQA, multi-query attention, KV cache reduction, Llama]
duplicate_of: none
source_trust_level: A
confidence_score: 0.96
verification_status: applied
tags: [transformer, attention, gqa, mqa, kv-cache, llama, inference-optimization]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / vLLM
---

# Grouped-Query Attention (GQA)

## In One Line

> **"The middle ground between multi-head attention and multi-query attention."** Introduced by Ainslie et al. (Google, 2023). Uses N query heads but only G key/value heads (G < N), so the KV cache shrinks while quality stays close to MHA. Used in Llama 2 70B, Mistral, and most modern LLMs.

## Core Concepts

### Spectrum

- **MHA**: Q=N, K=N, V=N heads (e.g. 32/32/32).
- **MQA**: Q=N, K=1, V=1 heads (e.g. 32/1/1).
- **GQA**: Q=N, K=G, V=G heads (e.g. 32/8/8).

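The head counts above imply a fixed mapping from query heads to KV heads. A minimal sketch, assuming consecutive query heads share one KV head (the layout that `repeat_interleave` produces); `kv_head_for` is a hypothetical helper name used only for illustration:

```python
def kv_head_for(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Return the index of the KV head that query head `q_head` reads from."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be divisible by n_kv_heads")
    group_size = n_q_heads // n_kv_heads  # query heads per KV head (N / G)
    return q_head // group_size


# Same 32 query heads, three points on the spectrum.
for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    mapping = [kv_head_for(h, 32, n_kv) for h in range(32)]
    print(f"{name}: {mapping}")
```

With 32 query heads and 8 KV heads, query heads 0 to 3 read KV head 0, heads 4 to 7 read KV head 1, and so on. MHA and MQA fall out as the two extremes (group size 1 and group size N).
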
### Trade-offs

- **MHA**: best quality, largest KV cache.
- **MQA**: smallest cache, at the cost of lower quality.
- **GQA**: the sweet spot between the two.

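### Inference Impact

- **KV cache** size (in elements) = batch × seq_len × n_layers × n_kv_heads × head_dim × 2 (one tensor each for K and V). Multiply by the bytes per element to get the size in bytes.
- With GQA, the number of KV heads drops from N to G, so the cache shrinks by a factor of N/G.
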
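As a worked example of the N/G reduction, the sketch below evaluates the per-token part of the formula for the 32/32, 32/8, and 32/1 head layouts. The layer count (32), head size (128), and 2-byte elements are assumptions chosen for illustration, not the shape of any particular model:

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int, dtype_bytes: int = 2) -> int:
    """Bytes of K and V stored for a single token across all layers."""
    return n_layers * n_kv_heads * head_dim * 2 * dtype_bytes  # 2 = one K plus one V


N_LAYERS, N_Q_HEADS, HEAD_DIM = 32, 32, 128  # hypothetical model shape

mha_bytes = kv_bytes_per_token(N_LAYERS, N_Q_HEADS, HEAD_DIM)
for name, n_kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_bytes_per_token(N_LAYERS, n_kv_heads, HEAD_DIM)
    print(f"{name}: {size // 1024} KiB per token, {mha_bytes // size}x smaller than MHA")
```

The reduction factor depends only on N/G: layers, head size, and dtype cancel out of the ratio.
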
### Where It Is Used

- **Llama 2 70B**: 64 Q heads, 8 KV heads.
- **Llama 3**: GQA as standard.
- **Mistral**, **Mixtral**: GQA.
- **Gemma**, **Qwen**: GQA in recent releases (e.g. Gemma 2, Qwen2).

## 💻 Patterns

### MHA (baseline)

```python
import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```

### GQA

```python
import torch
import torch.nn.functional as F

class GQA(torch.nn.Module):
    def __init__(self, dim, n_q_heads, n_kv_heads):
        super().__init__()
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_q_heads // n_kv_heads  # query heads per KV head
        self.head_dim = dim // n_q_heads

        self.q_proj = torch.nn.Linear(dim, n_q_heads * self.head_dim)
        self.k_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.v_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Repeat each KV head n_rep times so K/V match the number of Q heads
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```

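The `GQA` module computes `n_rep` and `head_dim` with integer division, so a mismatched configuration would typically surface later as a shape error instead of at construction time. A minimal validation sketch that could run before building the layer; `check_gqa_config` is a hypothetical helper, not part of PyTorch:

```python
def check_gqa_config(dim: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Validate GQA head counts and return n_rep (query heads per KV head)."""
    if dim % n_q_heads != 0:
        raise ValueError(f"dim={dim} is not divisible by n_q_heads={n_q_heads}")
    if not 1 <= n_kv_heads <= n_q_heads:
        raise ValueError(f"n_kv_heads={n_kv_heads} must be between 1 and n_q_heads={n_q_heads}")
    if n_q_heads % n_kv_heads != 0:
        raise ValueError(f"n_q_heads={n_q_heads} is not divisible by n_kv_heads={n_kv_heads}")
    return n_q_heads // n_kv_heads


print(check_gqa_config(dim=4096, n_q_heads=32, n_kv_heads=8))  # valid: 4 query heads per KV head

try:
    check_gqa_config(dim=4096, n_q_heads=32, n_kv_heads=5)
except ValueError as err:
    print(f"rejected: {err}")
```
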
### MQA (extreme)

```python
import torch
import torch.nn.functional as F

class MQA(torch.nn.Module):
    """Multi-query attention: a single KV head shared by all query heads."""
    def __init__(self, dim, n_q_heads):
        super().__init__()
        self.n_q_heads = n_q_heads
        self.head_dim = dim // n_q_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, self.head_dim)  # single KV head
        self.v_proj = torch.nn.Linear(dim, self.head_dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        # expand() broadcasts the single KV head across all query heads without copying
        k = self.k_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        v = self.v_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
```

### KV cache size (calculation)

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    # The factor 2 covers K and V; dtype_bytes=2 assumes fp16/bf16 storage.
    return batch * seq_len * n_layers * n_kv_heads * head_dim * 2 * dtype_bytes

# Llama 2 70B shape with MHA (hypothetical baseline): 80 layers, 64 KV heads, head_dim 128
# batch=1, seq=2048
# cache = 1 * 2048 * 80 * 64 * 128 * 2 * 2 bytes ≈ 5.4 GB

# Llama 2 70B as released (GQA): 8 KV heads
# cache = 1 * 2048 * 80 * 8 * 128 * 2 * 2 bytes ≈ 0.67 GB (8x smaller)
```

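The same formula can be inverted to ask how many tokens fit in a fixed cache budget. A rough sketch using the Llama 2 70B shape from the calculation above; the 16 GiB budget is an arbitrary assumption, and `max_cached_tokens` is a hypothetical helper that ignores weights, activations, and allocator overhead:

```python
def max_cached_tokens(budget_bytes: int, n_layers: int, n_kv_heads: int,
                      head_dim: int, dtype_bytes: int = 2) -> int:
    """Total tokens (batch * seq_len) whose K and V fit in budget_bytes."""
    bytes_per_token = n_layers * n_kv_heads * head_dim * 2 * dtype_bytes
    return budget_bytes // bytes_per_token


BUDGET = 16 * 1024**3  # assumed 16 GiB reserved for the KV cache

mha_tokens = max_cached_tokens(BUDGET, n_layers=80, n_kv_heads=64, head_dim=128)
gqa_tokens = max_cached_tokens(BUDGET, n_layers=80, n_kv_heads=8, head_dim=128)
print(f"MHA (64 KV heads): {mha_tokens} tokens")
print(f"GQA (8 KV heads):  {gqa_tokens} tokens")
```

Under these assumptions the GQA layout holds about eight times as many cached tokens, which can be spent on longer contexts, larger batches, or both.
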
### Convert MHA → GQA (mean pooling)

```python
def mha_to_gqa(k_proj, v_proj, n_q_heads, n_kv_heads):
    """Mean-pool each group of N/G MHA K/V heads into one GQA head.

    Only the weights are pooled; biases, if present, are not handled here.
    """
    n_rep = n_q_heads // n_kv_heads
    head_dim = k_proj.weight.size(0) // n_q_heads

    def pool(weight):
        # (N * head_dim, in) -> (G, N/G, head_dim, in) -> mean over each group -> (G * head_dim, in)
        grouped = weight.reshape(n_kv_heads, n_rep, head_dim, -1)
        return grouped.mean(dim=1).reshape(n_kv_heads * head_dim, -1)

    return pool(k_proj.weight), pool(v_proj.weight)
```

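To show where the pooled weights end up, the sketch below copies a mean-pooled K projection into a smaller `Linear` layer. It assumes bias-free projections and toy sizes, and `mean_pool_heads` is a hypothetical stand-alone version of the pooling step. The pooled weights are only a starting point; quality still needs to be checked (see the anti-pattern on converting without a retraining check).

```python
import torch


def mean_pool_heads(weight: torch.Tensor, n_q_heads: int, n_kv_heads: int) -> torch.Tensor:
    """Average each group of n_q_heads // n_kv_heads heads in a (N * head_dim, in_dim) weight."""
    n_rep = n_q_heads // n_kv_heads
    head_dim = weight.size(0) // n_q_heads
    grouped = weight.reshape(n_kv_heads, n_rep, head_dim, weight.size(1))
    return grouped.mean(dim=1).reshape(n_kv_heads * head_dim, weight.size(1))


dim, n_q_heads, n_kv_heads = 64, 8, 2  # toy sizes, chosen for illustration
head_dim = dim // n_q_heads

mha_k_proj = torch.nn.Linear(dim, dim, bias=False)                    # MHA: 8 K heads
gqa_k_proj = torch.nn.Linear(dim, n_kv_heads * head_dim, bias=False)  # GQA: 2 K heads

with torch.no_grad():
    pooled = mean_pool_heads(mha_k_proj.weight, n_q_heads, n_kv_heads)
    gqa_k_proj.weight.copy_(pooled)

print(tuple(mha_k_proj.weight.shape), "->", tuple(gqa_k_proj.weight.shape))
```

The V projection would be handled the same way; the Q and output projections keep their original shapes.
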
### Llama-style GQA + RoPE

```python
class LlamaAttention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_q = config.num_attention_heads
        self.n_kv = config.num_key_value_heads  # fewer than n_q when GQA is used
        self.head_dim = config.hidden_size // self.n_q
        self.q_proj = torch.nn.Linear(config.hidden_size, self.n_q * self.head_dim, bias=False)
        self.k_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.v_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.o_proj = torch.nn.Linear(self.n_q * self.head_dim, config.hidden_size, bias=False)
        # RoPE is assumed to be a rotary-embedding module defined elsewhere
        self.rotary = RoPE(self.head_dim, config.max_position_embeddings)

    def forward(self, x, kv_cache=None):
        # project q, k, v; apply RoPE; update the KV cache
        ...
```

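The cache-handling part of such a `forward` could look roughly like the sketch below. It is not the Llama implementation: RoPE is omitted, the `(k, v)` tuple cache layout and the name `gqa_decode_step` are assumptions, and the inputs are random tensors standing in for projected q/k/v. The point it illustrates is that the cache stores only `n_kv_heads` heads and K/V are expanded just for the attention call.

```python
import torch
import torch.nn.functional as F


def gqa_decode_step(q, k_new, v_new, cache, n_rep):
    """Attend one new token against the cached keys/values.

    q:            (B, n_q_heads, 1, head_dim)
    k_new, v_new: (B, n_kv_heads, 1, head_dim)
    cache:        None, or a (k, v) tuple of shape (B, n_kv_heads, T, head_dim)
    """
    if cache is None:
        k, v = k_new, v_new
    else:
        k = torch.cat([cache[0], k_new], dim=2)  # append along the sequence axis
        v = torch.cat([cache[1], v_new], dim=2)
    # Expand to n_q_heads only for this call; the stored cache keeps n_kv_heads.
    k_full = k.repeat_interleave(n_rep, dim=1)
    v_full = v.repeat_interleave(n_rep, dim=1)
    # A single new query may see every cached position, so no causal mask is needed.
    out = F.scaled_dot_product_attention(q, k_full, v_full)
    return out, (k, v)


B, n_q_heads, n_kv_heads, head_dim = 1, 8, 2, 16  # toy sizes
cache = None
for _ in range(3):  # three decoding steps
    q = torch.randn(B, n_q_heads, 1, head_dim)
    k_new = torch.randn(B, n_kv_heads, 1, head_dim)
    v_new = torch.randn(B, n_kv_heads, 1, head_dim)
    out, cache = gqa_decode_step(q, k_new, v_new, cache, n_q_heads // n_kv_heads)

print(tuple(out.shape), tuple(cache[0].shape))
```

After three steps the output still has 8 heads, while the cached K tensor holds 3 positions for only 2 heads.
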
### vLLM (production GQA serving)

```python
from vllm import LLM

llm = LLM(model='meta-llama/Llama-3.1-70B-Instruct')
# Internally uses GQA (taken from the model config) with paged attention and flash attention kernels
```

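### Flash Attention with GQA

```python
from flash_attn import flash_attn_func

# Supports GQA natively: k and v may have fewer heads than q,
# as long as the number of q heads is divisible by the number of kv heads.
# Expected shapes: q (B, T, n_q_heads, head_dim); k, v (B, T, n_kv_heads, head_dim)
out = flash_attn_func(q, k, v, causal=True)
```
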
### Benchmark KV cache savings

```python
import time
import torch

def benchmark(model, batch_sizes, seq_lens):
    for b in batch_sizes:
        for s in seq_lens:
            try:
                torch.cuda.reset_peak_memory_stats()  # report the peak for this run only
                inputs = torch.randint(0, 32000, (b, s)).cuda()  # 32000 = assumed vocab size
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                model.generate(inputs, max_new_tokens=128)
                torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer
                elapsed = time.perf_counter() - t0
                peak_gb = torch.cuda.max_memory_allocated() / 1e9
                print(f'b={b}, s={s}: {elapsed:.2f}s, peak={peak_gb:.2f}GB')
            except torch.cuda.OutOfMemoryError:
                print(f'b={b}, s={s}: OOM')
                torch.cuda.empty_cache()
```

## Decision Criteria

| Situation | Approach |
|---|---|
| Long context | GQA (essential) |
| High batch | GQA / MQA |
| Quality-critical small model | MHA |
| Mobile / edge | MQA |
| Llama-style | GQA (8 KV heads) |
| Modern default | GQA |

**Defaults**: for a modern LLM, use GQA (4-8 KV heads). If GQA costs more than about 1% in quality, use MHA. Under extreme memory constraints, use MQA.

## 🔗 Graph

- Parents: [[Attention-Mechanism]] · [[Transformer]]
- Variants: [[Multi-Head-Attention]] · [[Multi-Query-Attention]]
- Applications: [[Llama]] · [[Flash Attention]] · [[LLM_Optimization_and_Deployment_Strategies|vLLM]]
- Adjacent: [[KV-Cache]] · [[Paged-Attention]] · [[Foundation-Models]]

## 🤖 LLM Usage

**When to use**: any modern LLM, long contexts, high-batch serving.
**When not to use**: very small models (< 1B parameters).

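## ❌ Anti-patterns

- **MHA in a large production LLM**: the KV cache runs out of memory.
- **Wrong N/G ratio**: a poorly chosen ratio degrades quality.
- **No GQA-aware kernel**: materializing repeated K/V heads wastes memory and bandwidth.
- **Converting without a retraining check**: converting MHA weights to GQA without uptraining and evaluation can cause a sharp quality drop.
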
## 🧪 Verification / Duplicates

- Verified (Ainslie GQA 2023, Llama 2/3 papers, Shazeer MQA 2019).
- Trust level: A.

## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: MHA/MQA/GQA + KV cache calc / Llama / vLLM code |