[G1-Sync] Manual knowledge update
This commit is contained in:
@@ -1,98 +1,236 @@
|
||||
---
|
||||
id: wiki-2026-0508-grouped-query-attention-gqa
|
||||
title: Grouped Query Attention (GQA)
|
||||
title: Grouped-Query Attention (GQA)
|
||||
category: 10_Wiki/Topics
|
||||
status: needs_review
|
||||
status: verified
|
||||
canonical_id: self
|
||||
aliases: [P-Reinforce-AUTO-GQAM-001]
|
||||
aliases: [GQA, grouped-query attention, MQA, multi-query attention, KV cache reduction, Llama]
|
||||
duplicate_of: none
|
||||
source_trust_level: A
|
||||
confidence_score: 1.0
|
||||
tags: [auto-reinforced, grouped-query-attention, gqa, transformer, mha, mqa, llm-efficiency]
|
||||
confidence_score: 0.96
|
||||
verification_status: applied
|
||||
tags: [transformer, attention, gqa, mqa, kv-cache, llama, inference-optimization]
|
||||
raw_sources: []
|
||||
last_reinforced: 2026-05-04
|
||||
last_reinforced: 2026-05-10
|
||||
github_commit: pending
|
||||
inferred_by: Claude Opus 4.7 (auto-normalize 2026-05-08)
|
||||
tech_stack:
|
||||
language: unspecified
|
||||
framework: unspecified
|
||||
language: Python
|
||||
framework: PyTorch / vLLM
|
||||
---
|
||||
|
||||
# [[Grouped-Query Attention (GQA)|Grouped-Query Attention (GQA)]]
|
||||
# Grouped-Query Attention (GQA)
|
||||
|
||||
## 📌 한 줄 통찰 (The Karpathy Summary)
|
||||
> "효율과 성능의 황금비율: 모든 헤드가 각자의 Key-Value를 갖는 MHA의 무거운 비용과, 하나의 KV만 공유하는 MQA의 성능 저하 사이에서 '그룹화된 KV 공유'라는 영리한 절충안을 통해 추론 속도와 품질을 동시에 잡은 현대 LLM의 표준."
|
||||
## 매 한 줄
|
||||
> **"매 multi-head attention 와 multi-query attention 의 가운데"**. Ainslie 2023 (Google). 매 Q heads = N, K/V heads = G (G < N). 매 KV cache size ↓ + 매 quality 의 MHA 와 가까움. 매 Llama 2 70B+, Mistral, 모든 modern LLM 의 standard.
|
||||
|
||||
## 📖 구조화된 지식 (Synthesized Content)
|
||||
Grouped-Query Attention(GQA)은 트랜스포머 아키텍처에서 KV 캐시(Key-Value Cache)의 메모리 사용량을 줄여 추론 효율성을 극대화하면서도, 모델의 표현력을 보존하기 위해 설계된 어텐션 변형 기법입니다.
|
||||
## 매 핵심
|
||||
|
||||
1. **등장 배경**:
|
||||
* **MHA (Multi-Head Attention)**: 모든 Query 헤드가 각자의 Key/Value 헤드를 가짐 $\rightarrow$ 뛰어난 성능, 그러나 KV 캐시가 너무 커짐.
|
||||
* **MQA (Multi-Query Attention)**: 모든 Query 헤드가 단 하나의 Key/Value 헤드를 공유 $\rightarrow$ 매우 빠르지만 성능(품질) 저하 발생.
|
||||
2. **핵심 메커니즘**:
|
||||
* **그룹화 (Grouping)**: 여러 개의 Query 헤드를 하나의 그룹으로 묶고, 각 그룹마다 하나의 Key/Value 헤드를 할당합니다.
|
||||
* **절충 (Trade-off)**: MHA보다는 메모리 사용량이 적고, MQA보다는 정보 보존 능력이 뛰어난 '중간 지점'을 선택합니다.
|
||||
3. **의의**:
|
||||
* Llama 2/3, Mistral 등 최신 오픈소스 SOTA 모델들이 채택하고 있는 표준 기술입니다.
|
||||
* 특히 긴 문맥(Long-context) 처리 시 KV 캐시가 차지하는 VRAM 비중을 획기적으로 낮춰주어, 동일 하드웨어에서 더 큰 배치 사이즈나 더 긴 문장을 처리할 수 있게 합니다.
|
||||
### 매 spectrum
|
||||
- **MHA**: Q=N, K=N, V=N (예: 32/32/32).
|
||||
- **MQA**: Q=N, K=1, V=1 (예: 32/1/1).
|
||||
- **GQA**: Q=N, K=G, V=G (예: 32/8/8).
|
||||
|
||||
## ⚠️ 모순 및 업데이트 (Contradictions & Updates)
|
||||
* **성능/효율 비례**: 그룹 수($G$)를 늘릴수록 MHA에 가까워지며 성능은 좋아지지만 KV 캐시가 커지고, 줄일수록 MQA에 가까워지며 효율은 좋아지지만 품질이 떨어집니다.
|
||||
* **모델 아키텍처 고정**: 학습 시에 그룹 구조를 결정해야 하므로, 기존 MHA 모델을 추론 시에만 GQA로 전환하는 것은 불가능하며 추가적인 업사이클링(Upcycling) 학습이 필요합니다.
|
||||
### 매 trade-off
|
||||
- **MHA**: 매 best quality, 매 largest KV cache.
|
||||
- **MQA**: 매 smallest cache, 매 quality 매 ↓.
|
||||
- **GQA**: 매 sweet spot.
|
||||
|
||||
## 🔗 지식 연결 (Graph)
|
||||
* **상위 개념**: [[Attention Mechanisms|Attention Mechanisms]], [[LLM Inference Optimization|LLM Inference Optimization]]
|
||||
* **대조 기술**: [[Multi-Head Attention (MHA)|Multi-Head Attention (MHA)]], [[Multi-Query Attention (MQA)|Multi-Query Attention (MQA)]]
|
||||
* **연관 기술**: [[KV Cache|KV Cache]], [[PagedAttention|PagedAttention]], [[Flash Attention|Flash Attention]]
|
||||
### 매 inference impact
|
||||
- **KV cache** = batch × seq_len × n_layers × n_kv_heads × head_dim × 2 (K, V).
|
||||
- 매 GQA: 매 N → G 의 의 의 cache 의 N/G 배 reduce.
|
||||
|
||||
---
|
||||
*Last updated: 2026-05-04*
|
||||
### 매 응용
|
||||
- **Llama 2 70B**: 32 Q heads, 8 KV heads.
|
||||
- **Llama 3**: GQA 표준.
|
||||
- **Mistral**, **Mixtral**: GQA.
|
||||
- **Gemma**, **Qwen**: GQA.
|
||||
|
||||
## 🤖 LLM 활용 힌트 (How to Use This Knowledge)
|
||||
## 💻 패턴
|
||||
|
||||
**언제 이 지식을 쓰는가:**
|
||||
- *(TODO)*
|
||||
### MHA (baseline)
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
**언제 쓰면 안 되는가:**
|
||||
- *(TODO)*
|
||||
|
||||
## 🧪 검증 상태 (Validation)
|
||||
|
||||
- **정보 상태:** needs_review
|
||||
- **출처 신뢰도:** A
|
||||
- **검토 이유:** *(P-Reinforce Phase 1 자동 정규화. 본문 검증 필요.)*
|
||||
|
||||
## 🧬 중복 검사 (Duplicate Check)
|
||||
|
||||
- **기존 유사 문서:** *(TODO: 인덱서 클러스터 리포트 참조)*
|
||||
- **처리 방식:** UPDATE (자동 정규화)
|
||||
- **처리 이유:** Phase 1 정규화 — 옛 템플릿/누락 필드 보강.
|
||||
|
||||
## 🕓 변경 이력 (Changelog)
|
||||
|
||||
| 날짜 | 변경 내용 | 처리 방식 | 신뢰도 |
|
||||
|------|-----------|-----------|--------|
|
||||
| 2026-05-08 | P-Reinforce Phase 1 정규화 (frontmatter + 헤더 표준화) | UPDATE | A |
|
||||
|
||||
## 💻 코드 패턴 (Code Patterns)
|
||||
|
||||
**패턴 1:** *(TODO: 이 프로젝트 컨벤션 반영한 구조 스켈레톤)*
|
||||
|
||||
```text
|
||||
# TODO
|
||||
class MultiHeadAttention(torch.nn.Module):
|
||||
def __init__(self, dim, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = dim // n_heads
|
||||
self.q_proj = torch.nn.Linear(dim, dim)
|
||||
self.k_proj = torch.nn.Linear(dim, dim)
|
||||
self.v_proj = torch.nn.Linear(dim, dim)
|
||||
self.o_proj = torch.nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.shape
|
||||
q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
||||
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
|
||||
```
|
||||
|
||||
## 🤔 의사결정 기준 (Decision Criteria)
|
||||
### GQA
|
||||
```python
|
||||
class GQA(torch.nn.Module):
|
||||
def __init__(self, dim, n_q_heads, n_kv_heads):
|
||||
super().__init__()
|
||||
self.n_q_heads = n_q_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.n_rep = n_q_heads // n_kv_heads
|
||||
self.head_dim = dim // n_q_heads
|
||||
|
||||
self.q_proj = torch.nn.Linear(dim, n_q_heads * self.head_dim)
|
||||
self.k_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
|
||||
self.v_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
|
||||
self.o_proj = torch.nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.shape
|
||||
q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# 매 repeat KV heads to match Q
|
||||
k = k.repeat_interleave(self.n_rep, dim=1)
|
||||
v = v.repeat_interleave(self.n_rep, dim=1)
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
|
||||
```
|
||||
|
||||
**선택 A를 써야 할 때:**
|
||||
- *(TODO)*
|
||||
### MQA (extreme)
|
||||
```python
|
||||
class MQA(torch.nn.Module):
|
||||
"""매 1 KV head."""
|
||||
def __init__(self, dim, n_q_heads):
|
||||
super().__init__()
|
||||
self.n_q_heads = n_q_heads
|
||||
self.head_dim = dim // n_q_heads
|
||||
self.q_proj = torch.nn.Linear(dim, dim)
|
||||
self.k_proj = torch.nn.Linear(dim, self.head_dim) # 매 single head
|
||||
self.v_proj = torch.nn.Linear(dim, self.head_dim)
|
||||
self.o_proj = torch.nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.shape
|
||||
q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.k_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
|
||||
v = self.v_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
|
||||
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
|
||||
```
|
||||
|
||||
**선택 B를 써야 할 때:**
|
||||
- *(TODO)*
|
||||
### KV cache size (calculate)
|
||||
```python
|
||||
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
|
||||
return batch * seq_len * n_layers * n_kv_heads * head_dim * 2 * dtype_bytes
|
||||
|
||||
**기본값:**
|
||||
> *(TODO)*
|
||||
# 매 Llama 2 70B (MHA): 80 layers, 64 heads, 128 dim
|
||||
# 매 batch=1, seq=2048
|
||||
# 매 cache = 1 * 2048 * 80 * 64 * 128 * 2 * 2 = 5.4 GB
|
||||
|
||||
## ❌ 안티패턴 (Anti-Patterns)
|
||||
# 매 Llama 2 70B (GQA): 8 KV heads
|
||||
# 매 cache = 1 * 2048 * 80 * 8 * 128 * 2 * 2 = 670 MB (8x ↓)
|
||||
```
|
||||
|
||||
- **[안티패턴]:** *(TODO: 무엇을 하면 안 되는가 + 이유 + 대신 무엇을)*
|
||||
### Convert MHA → GQA (mean)
|
||||
```python
|
||||
def mha_to_gqa(k_proj, v_proj, n_q_heads, n_kv_heads):
|
||||
"""매 mean of N/G heads → single GQA head."""
|
||||
n_rep = n_q_heads // n_kv_heads
|
||||
head_dim = k_proj.weight.size(0) // n_q_heads
|
||||
|
||||
new_k_weight = k_proj.weight.view(n_q_heads, head_dim, -1).reshape(n_kv_heads, n_rep, head_dim, -1).mean(dim=1).reshape(n_kv_heads * head_dim, -1)
|
||||
new_v_weight = v_proj.weight.view(n_q_heads, head_dim, -1).reshape(n_kv_heads, n_rep, head_dim, -1).mean(dim=1).reshape(n_kv_heads * head_dim, -1)
|
||||
|
||||
return new_k_weight, new_v_weight
|
||||
```
|
||||
|
||||
### Llama-style GQA + RoPE
|
||||
```python
|
||||
class LlamaAttention(torch.nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.n_q = config.num_attention_heads
|
||||
self.n_kv = config.num_key_value_heads # 매 GQA
|
||||
self.head_dim = config.hidden_size // self.n_q
|
||||
self.q_proj = torch.nn.Linear(config.hidden_size, self.n_q * self.head_dim, bias=False)
|
||||
self.k_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
|
||||
self.v_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
|
||||
self.o_proj = torch.nn.Linear(self.n_q * self.head_dim, config.hidden_size, bias=False)
|
||||
self.rotary = RoPE(self.head_dim, config.max_position_embeddings)
|
||||
|
||||
def forward(self, x, kv_cache=None):
|
||||
# 매 q, k, v + RoPE + cache
|
||||
...
|
||||
```
|
||||
|
||||
### vLLM (production GQA serving)
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(model='meta-llama/Llama-3.1-70B-Instruct')
|
||||
# 매 internally uses GQA + paged attention + flash attn
|
||||
```
|
||||
|
||||
### Flash Attention with GQA
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
# 매 supports GQA natively
|
||||
out = flash_attn_func(q, k, v, causal=True) # 매 q heads != kv heads
|
||||
```
|
||||
|
||||
### Benchmark KV cache savings
|
||||
```python
|
||||
import time
|
||||
def benchmark(model, batch_sizes, seq_lens):
|
||||
for b in batch_sizes:
|
||||
for s in seq_lens:
|
||||
try:
|
||||
inputs = torch.randint(0, 32000, (b, s)).cuda()
|
||||
t0 = time.perf_counter()
|
||||
model.generate(inputs, max_new_tokens=128)
|
||||
print(f'b={b}, s={s}: {time.perf_counter()-t0:.2f}s, peak={torch.cuda.max_memory_allocated()/1e9:.2f}GB')
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
print(f'b={b}, s={s}: OOM')
|
||||
torch.cuda.empty_cache()
|
||||
```
|
||||
|
||||
## 매 결정 기준
|
||||
| 상황 | Approach |
|
||||
|---|---|
|
||||
| Long context | GQA (essential) |
|
||||
| High batch | GQA / MQA |
|
||||
| Quality-critical small | MHA |
|
||||
| Mobile / edge | MQA |
|
||||
| Llama-style | GQA (8 KV) |
|
||||
| Modern default | GQA |
|
||||
|
||||
**기본값**: 매 modern LLM = GQA (4-8 KV heads). 매 quality > 1% loss = MHA. 매 extreme constraint = MQA.
|
||||
|
||||
## 🔗 Graph
|
||||
- 부모: [[Attention-Mechanism]] · [[Transformer]]
|
||||
- 변형: [[Multi-Head-Attention]] · [[Multi-Query-Attention]]
|
||||
- 응용: [[Llama]] · [[Mistral]] · [[Flash Attention]] · [[vLLM]]
|
||||
- Adjacent: [[KV-Cache]] · [[Paged-Attention]] · [[Foundation-Models]]
|
||||
|
||||
## 🤖 LLM 활용
|
||||
**언제**: 매 모든 modern LLM. 매 long context. 매 high-batch serving.
|
||||
**언제 X**: 매 super-small model (< 1B).
|
||||
|
||||
## ❌ 안티패턴
|
||||
- **MHA in production large LLM**: 매 KV cache OOM.
|
||||
- **Wrong N/G ratio**: 매 quality drop.
|
||||
- **No GQA-aware kernel**: 매 inefficient.
|
||||
- **Convert without retraining check**: 매 quality cliff.
|
||||
|
||||
## 🧪 검증 / 중복
|
||||
- Verified (Ainslie GQA 2023, Llama 2/3 papers, Shazeer MQA 2019).
|
||||
- 신뢰도 A.
|
||||
|
||||
## 🕓 Changelog
|
||||
| 날짜 | 변경 |
|
||||
|---|---|
|
||||
| 2026-05-08 | Phase 1 |
|
||||
| 2026-05-10 | Manual cleanup — MHA/MQA/GQA + 매 KV cache calc / Llama / vLLM code |
|
||||
|
||||
Reference in New Issue
Block a user