---
id: wiki-2026-0508-grouped-query-attention-gqa
title: Grouped-Query Attention (GQA)
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [GQA, grouped-query attention, MQA, multi-query attention, KV cache reduction, Llama]
duplicate_of: none
source_trust_level: A
confidence_score: 0.96
verification_status: applied
tags: [transformer, attention, gqa, mqa, kv-cache, llama, inference-optimization]
raw_sources:
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / vLLM
---

Grouped-Query Attention (GQA)

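One-line summary

"The middle ground between multi-head attention and multi-query attention." Introduced by Ainslie et al. (Google, 2023). With N query heads and G key/value heads (G < N; G = 1 is MQA), each group of N/G query heads shares one K/V head. This shrinks the KV cache by a factor of N/G while keeping quality close to MHA. Used by Llama 2 70B, Mistral, and most recent open-weight LLMs.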

Key points

Spectrum

  • MHA: Q=N, K=N, V=N (e.g. 32/32/32).
  • MQA: Q=N, K=1, V=1 (e.g. 32/1/1).
  • GQA: Q=N, K=G, V=G (e.g. 32/8/8).
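The three variants differ only in how many query heads share one K/V head. A minimal sketch of that mapping, assuming consecutive query heads are grouped together (the layout that `repeat_interleave` produces in the GQA module later in this note). The names `kv_head_index` and `head_mapping` are illustrative, not part of any library:

```python
def kv_head_index(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Return the index of the K/V head that a given query head attends with."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be divisible by n_kv_heads")
    group_size = n_q_heads // n_kv_heads  # query heads per K/V head
    return q_head // group_size


def head_mapping(n_q_heads: int, n_kv_heads: int) -> list:
    """K/V head index for every query head, in order."""
    return [kv_head_index(q, n_q_heads, n_kv_heads) for q in range(n_q_heads)]


mha = head_mapping(8, 8)  # every query head has its own K/V head
gqa = head_mapping(8, 2)  # four query heads share each K/V head
mqa = head_mapping(8, 1)  # all query heads share one K/V head
print(mha, gqa, mqa)
```

Setting `n_kv_heads` equal to `n_q_heads` recovers MHA and setting it to 1 recovers MQA, so a single implementation can cover the whole spectrum.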

Trade-off

  • MHA: best quality, largest KV cache.
  • MQA: smallest cache, lower quality.
  • GQA: the sweet spot between the two.

Inference impact

  • KV cache size (in elements) = batch × seq_len × n_layers × n_kv_heads × head_dim × 2 (K and V); multiply by bytes per element to get bytes.
  • GQA: cutting the KV heads from N to G shrinks the cache by a factor of N/G.
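Read the other way around, the same formula tells you how many tokens fit in a fixed memory budget. A rough sketch, assuming fp16 storage and a hypothetical 16 GiB reserved for the cache; the helper names are made up for illustration, and the 80-layer, 128-dim shape is the one used in the cache calculation later in this note:

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    """Bytes of K and V stored for one token across all layers."""
    # The factor 2 accounts for storing both K and V.
    return n_layers * n_kv_heads * head_dim * 2 * dtype_bytes


def max_cached_tokens(budget_bytes, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    """Total tokens (batch x seq_len) whose K/V fit in budget_bytes."""
    per_token = kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes)
    return budget_bytes // per_token


budget = 16 * 1024**3                    # hypothetical 16 GiB set aside for the KV cache
shape = dict(n_layers=80, head_dim=128)  # assumed model shape

tokens_mha = max_cached_tokens(budget, n_kv_heads=64, **shape)
tokens_gqa = max_cached_tokens(budget, n_kv_heads=8, **shape)
print(tokens_mha, tokens_gqa)
```

With 8 KV heads instead of 64, the same budget holds roughly 8 times as many tokens, which can be spent on longer context or a larger batch.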

Applications

  • Llama 2 70B: 64 Q heads, 8 KV heads.
  • Llama 3: GQA is standard across model sizes.
  • Mistral, Mixtral: GQA.
  • Gemma, Qwen: GQA in recent releases (Gemma 2, Qwen2).
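To check which variant a given model uses, compare its two head counts. A small sketch, assuming the Hugging Face style field names `num_attention_heads` and `num_key_value_heads` that the Llama-style module in this note also reads; `attention_variant` is an illustrative helper, not a library function:

```python
def attention_variant(num_attention_heads: int, num_key_value_heads: int) -> str:
    """Classify an attention layout from its Q and K/V head counts."""
    if num_key_value_heads < 1 or num_attention_heads % num_key_value_heads != 0:
        raise ValueError("K/V head count must be >= 1 and divide the Q head count")
    if num_key_value_heads == num_attention_heads:
        return "MHA"
    if num_key_value_heads == 1:
        return "MQA"
    return "GQA"


print(attention_variant(64, 8))   # 64 Q heads sharing 8 KV heads
print(attention_variant(32, 32))  # one KV head per Q head
print(attention_variant(32, 1))   # a single shared KV head
```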

💻 Patterns

MHA (baseline)

import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)
        self.o_proj = torch.nn.Linear(dim, dim)
    
    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))

GQA

class GQA(torch.nn.Module):
    def __init__(self, dim, n_q_heads, n_kv_heads):
        super().__init__()
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
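        assert n_q_heads % n_kv_heads == 0, "n_q_heads must be divisible by n_kv_heads"
        self.n_rep = n_q_heads // n_kv_heads  # query heads sharing each KV head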
        self.head_dim = dim // n_q_heads
        
        self.q_proj = torch.nn.Linear(dim, n_q_heads * self.head_dim)
        self.k_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.v_proj = torch.nn.Linear(dim, n_kv_heads * self.head_dim)
        self.o_proj = torch.nn.Linear(dim, dim)
    
    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        
        # Repeat each KV head n_rep times so K and V line up with the Q heads
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))
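
Besides the cache, the smaller `k_proj` and `v_proj` layers also cut parameters by the same N/G factor. A quick count, assuming layers shaped like the modules in this note; `kv_proj_params` is an illustrative helper and the 4096/32 shape is just an example:

```python
def kv_proj_params(dim: int, n_q_heads: int, n_kv_heads: int, bias: bool = True) -> int:
    """Parameter count of k_proj plus v_proj for one attention layer."""
    head_dim = dim // n_q_heads
    out_features = n_kv_heads * head_dim
    per_proj = dim * out_features + (out_features if bias else 0)
    return 2 * per_proj  # one projection for K, one for V


dim, n_q = 4096, 32  # example shape, not tied to a specific model
mha_params = kv_proj_params(dim, n_q, 32, bias=False)
gqa_params = kv_proj_params(dim, n_q, 8, bias=False)
mqa_params = kv_proj_params(dim, n_q, 1, bias=False)
print(mha_params, gqa_params, mqa_params)
```

The Q and output projections are unchanged, so the overall parameter saving is much smaller than the KV cache saving.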

MQA (extreme)

class MQA(torch.nn.Module):
    """MQA: all query heads share a single K/V head."""
    def __init__(self, dim, n_q_heads):
        super().__init__()
        self.n_q_heads = n_q_heads
        self.head_dim = dim // n_q_heads
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, self.head_dim)  # one K head shared by all Q heads
        self.v_proj = torch.nn.Linear(dim, self.head_dim)
        self.o_proj = torch.nn.Linear(dim, dim)
    
    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        v = self.v_proj(x).view(B, T, 1, self.head_dim).transpose(1, 2).expand(-1, self.n_q_heads, -1, -1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, C))

KV cache size (calculate)

def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    return batch * seq_len * n_layers * n_kv_heads * head_dim * 2 * dtype_bytes

# Hypothetical Llama 2 70B with full MHA: 80 layers, 64 KV heads, head_dim 128
# batch=1, seq=2048, fp16 (2 bytes per element)
# cache = 1 * 2048 * 80 * 64 * 128 * 2 * 2 ≈ 5.4 GB

# Actual Llama 2 70B (GQA): 8 KV heads
# cache = 1 * 2048 * 80 * 8 * 128 * 2 * 2 ≈ 0.67 GB (8x smaller)

Convert MHA → GQA (mean)

def mha_to_gqa(k_proj, v_proj, n_q_heads, n_kv_heads):
    """Mean-pool each group of N/G K/V heads into one GQA head (Ainslie et al. 2023); uptrain afterwards."""
    n_rep = n_q_heads // n_kv_heads
    head_dim = k_proj.weight.size(0) // n_q_heads
    
    new_k_weight = k_proj.weight.view(n_q_heads, head_dim, -1).reshape(n_kv_heads, n_rep, head_dim, -1).mean(dim=1).reshape(n_kv_heads * head_dim, -1)
    new_v_weight = v_proj.weight.view(n_q_heads, head_dim, -1).reshape(n_kv_heads, n_rep, head_dim, -1).mean(dim=1).reshape(n_kv_heads * head_dim, -1)
    
    return new_k_weight, new_v_weight
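
What the reshape-and-mean does is easier to see on a toy example. The sketch below uses plain Python lists in place of weight tensors and assumes, as the conversion function does, that consecutive heads form a group; `mean_pool_heads` is an illustrative name:

```python
def mean_pool_heads(heads, n_kv_heads):
    """Average consecutive groups of per-head vectors into n_kv_heads vectors."""
    n_q_heads = len(heads)
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("number of heads must be divisible by n_kv_heads")
    n_rep = n_q_heads // n_kv_heads  # original heads merged into each new head
    pooled = []
    for g in range(n_kv_heads):
        group = heads[g * n_rep:(g + 1) * n_rep]
        # Element-wise mean over the heads in this group.
        pooled.append([sum(column) / n_rep for column in zip(*group)])
    return pooled


# Four toy "heads", each a 2-element vector standing in for a weight slice.
heads = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
pooled = mean_pool_heads(heads, n_kv_heads=2)
print(pooled)  # heads 0-1 merge into the first vector, heads 2-3 into the second
```

Grouping consecutive heads matches the order in which the GQA module repeats KV heads, so pooled head g serves query heads g·n_rep through (g+1)·n_rep − 1.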

Llama-style GQA + RoPE

class LlamaAttention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_q = config.num_attention_heads
        self.n_kv = config.num_key_value_heads  # fewer than n_q under GQA
        self.head_dim = config.hidden_size // self.n_q
        self.q_proj = torch.nn.Linear(config.hidden_size, self.n_q * self.head_dim, bias=False)
        self.k_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.v_proj = torch.nn.Linear(config.hidden_size, self.n_kv * self.head_dim, bias=False)
        self.o_proj = torch.nn.Linear(self.n_q * self.head_dim, config.hidden_size, bias=False)
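        # RoPE stands in for a rotary-embedding module; it is not defined in this note
        self.rotary = RoPE(self.head_dim, config.max_position_embeddings)
    
    def forward(self, x, kv_cache=None):
        # Outline: project q/k/v, apply RoPE to q and k, update the KV cache, attend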
        ...

vLLM (production GQA serving)

from vllm import LLM
llm = LLM(model='meta-llama/Llama-3.1-70B-Instruct')
# vLLM handles GQA internally, together with PagedAttention and FlashAttention kernels

Flash Attention with GQA

from flash_attn import flash_attn_func
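# flash_attn_func supports GQA/MQA natively: q is (batch, seqlen, n_q_heads, head_dim),
# k and v are (batch, seqlen, n_kv_heads, head_dim), with n_q_heads divisible by n_kv_heads
out = flash_attn_func(q, k, v, causal=True)  # no need to repeat the KV heads first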

Benchmark KV cache savings

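import time
import torch

def benchmark(model, batch_sizes, seq_lens):
    """Time generate() and report peak GPU memory for each (batch, seq_len) pair."""
    for b in batch_sizes:
        for s in seq_lens:
            # Reset the counter so the reported peak belongs to this run only
            torch.cuda.reset_peak_memory_stats()
            try:
                # Random token ids; assumes the model's vocabulary has at least 32000 entries
                inputs = torch.randint(0, 32000, (b, s), device='cuda')
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                model.generate(inputs, max_new_tokens=128)
                torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
                print(f'b={b}, s={s}: {time.perf_counter()-t0:.2f}s, peak={torch.cuda.max_memory_allocated()/1e9:.2f}GB')
            except torch.cuda.OutOfMemoryError:
                print(f'b={b}, s={s}: OOM')
            torch.cuda.empty_cache()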

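Decision criteria

| Situation | Approach |
|---|---|
| Long context | GQA (essential) |
| High batch | GQA / MQA |
| Quality-critical small model | MHA |
| Mobile / edge | MQA |
| Llama-style | GQA (8 KV heads) |
| Modern default | GQA |

Default: GQA with 4-8 KV heads for a modern LLM. If that costs more than about 1% in quality, fall back to MHA. Under extreme memory constraints, use MQA.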
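
The table and default can be folded into a small helper. This is only a sketch of one way to encode them: the flag names, the order in which the rules are checked, and the fallback to the largest divisor at or below 8 are all assumptions, not something the sources prescribe:

```python
def choose_attention(n_q_heads, *, extreme_memory_limit=False,
                     small_quality_critical=False, default_kv_heads=8):
    """Return (variant, n_kv_heads) for a model with n_q_heads query heads."""
    if extreme_memory_limit:        # e.g. mobile / edge deployment
        return "MQA", 1
    if small_quality_critical:      # small model where quality matters most
        return "MHA", n_q_heads
    # Default: GQA with the largest KV head count <= default_kv_heads
    # that divides the number of query heads evenly.
    n_kv = min(default_kv_heads, n_q_heads)
    while n_q_heads % n_kv != 0:
        n_kv -= 1
    if n_kv == n_q_heads:
        return "MHA", n_kv
    if n_kv == 1:
        return "MQA", n_kv
    return "GQA", n_kv


print(choose_attention(32))
print(choose_attention(32, extreme_memory_limit=True))
print(choose_attention(12))
```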

🔗 Graph

🤖 LLM usage

When to use: essentially any modern LLM, especially for long context and high-batch serving. When not to: very small models (< 1B parameters).

Anti-patterns

  • MHA in a large production LLM: the KV cache runs out of memory.
  • Poorly chosen N/G ratio: quality drops.
  • No GQA-aware kernel: inefficient.
  • Converting MHA to GQA without retraining and re-checking quality: quality cliff.

🧪 Verification / duplicates

  • Verified against Ainslie et al. 2023 (GQA), the Llama 2/3 papers, and Shazeer 2019 (MQA).
  • Trust level: A.

🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: MHA/MQA/GQA + KV cache calc / Llama / vLLM code |