--- id: wiki-2026-0508-nlp-attention-mechanisms title: NLP Attention Mechanisms category: 10_Wiki/Topics status: verified canonical_id: self aliases: [Attention, Self-Attention, Multi-Head Attention, Bahdanau, Luong, Flash Attention] duplicate_of: none source_trust_level: A confidence_score: 0.95 verification_status: applied tags: [nlp, attention, transformer, deep-learning, flash-attention] raw_sources: [] last_reinforced: 2026-05-10 github_commit: pending tech_stack: { language: python, framework: pytorch } --- # NLP Attention Mechanisms ## 매 한 줄 **Attention**은 시퀀스 내 토큰 간 가중 의존성을 동적으로 학습하는 메커니즘으로, Bahdanau(2014) additive → Luong multiplicative → scaled dot-product → multi-head → Flash Attention 진화를 거쳐 모든 현대 LLM의 코어가 되었다. ## 매 핵심 ### 1. Attention 일반 공식 ``` attention(Q, K, V) = softmax(score(Q, K)) · V ``` - score 함수가 변형의 핵심: additive vs multiplicative vs dot-product. - 출력은 V의 가중합, 가중치는 Q-K 유사도. ### 2. Bahdanau (Additive, 2014) ``` score(q, k) = vᵀ tanh(W_q q + W_k k) ``` - MLP 기반 — 학습 파라미터 많음. - 작은 차원에서 더 풍부한 매칭. - seq2seq 번역 (RNN encoder-decoder) 출시. ### 3. Luong (Multiplicative, 2015) ``` score(q, k) = qᵀ W k (general) score(q, k) = qᵀ k (dot) ``` - 행렬 곱 1번 — 빠름. - GPU 친화적. ### 4. Scaled Dot-Product (Vaswani 2017, Transformer) ``` Attention(Q,K,V) = softmax(QKᵀ / √d_k) V ``` - √d_k로 나눠 큰 차원 softmax saturation 방지. - 행렬 연산 — 병렬 GPU 최적. - 이게 Transformer의 핵심. ### 5. Multi-Head Attention ``` MHA(Q,K,V) = Concat(head_1, ..., head_h) W^O head_i = Attention(Q W^Q_i, K W^K_i, V W^V_i) ``` - h개의 head가 서로 다른 sub-space에서 attention. - 표현력 ↑ — 한 head는 syntactic, 다른 head는 semantic. - 표준 h = 8, 16, 32, 64 (모델 크기 의존). ### 6. Self vs Cross Attention - **Self**: Q=K=V (같은 시퀀스) — encoder, decoder masked. - **Cross**: Q from decoder, K=V from encoder — encoder-decoder bridge. ### 7. Causal / Masked Attention - Decoder에서 미래 토큰 참조 차단 (-inf 마스크). - LLM autoregressive 생성 표준. ### 8. Positional Encoding - Attention은 순서 무인지 → 위치 정보 추가 필요. - **Sinusoidal** (원조), **Learned**, **RoPE** (Rotary, LLaMA/현대 LLM 표준), **ALiBi**. ### 9. Modern: Flash Attention (Dao 2022, FA2 2023, FA3 2024) - IO-aware 알고리즘: GPU SRAM 활용해 HBM 왕복 최소화. - 정확한 attention (근사 X) — 2-4× 빠름, 메모리 5-20× 절감. - 긴 컨텍스트(100K-1M token)의 게임 체인저. - FA2: warp-level 병렬화. FA3 (Hopper): WGMMA + async. ### 10. 효율 변형 - **MQA (Multi-Query Attention)**: KV 헤드 1개 — 추론 빠름. - **GQA (Grouped-Query)**: KV 헤드 그룹화 — LLaMA-2/3 표준. - **Sliding Window**: local attention (Mistral). - **Sparse / Linear / Linformer / Performer**: O(n²) → O(n log n) 또는 O(n). - **Ring Attention**: 분산 long-context (Gemini 2M). ## 💻 패턴 ```python # 1. Bahdanau additive (PyTorch) import torch, torch.nn as nn class BahdanauAttention(nn.Module): def __init__(self, d): super().__init__() self.W_q = nn.Linear(d, d); self.W_k = nn.Linear(d, d); self.v = nn.Linear(d, 1) def forward(self, q, k, v): # q: (B, 1, d), k/v: (B, T, d) score = self.v(torch.tanh(self.W_q(q) + self.W_k(k))).squeeze(-1) # (B, T) a = torch.softmax(score, dim=-1) return (a.unsqueeze(-1) * v).sum(dim=1) ``` ```python # 2. Scaled dot-product def scaled_dot_product(Q, K, V, mask=None): d_k = Q.size(-1) scores = (Q @ K.transpose(-2, -1)) / d_k**0.5 if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) return torch.softmax(scores, dim=-1) @ V ``` ```python # 3. Multi-head (from scratch) class MHA(nn.Module): def __init__(self, d, h): super().__init__() self.h, self.dh = h, d // h self.qkv = nn.Linear(d, 3*d); self.o = nn.Linear(d, d) def forward(self, x, mask=None): B, T, D = x.shape q,k,v = self.qkv(x).chunk(3, dim=-1) q,k,v = [t.view(B,T,self.h,self.dh).transpose(1,2) for t in (q,k,v)] out = scaled_dot_product(q,k,v,mask).transpose(1,2).reshape(B,T,D) return self.o(out) ``` ```python # 4. PyTorch built-in mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True) out, attn_weights = mha(x, x, x) # self-attention ``` ```python # 5. Causal mask (decoder) T = 1024 mask = torch.tril(torch.ones(T, T)).bool() # lower-tri: 1 keep, 0 mask ``` ```python # 6. Flash Attention (xformers / pytorch SDPA backend) import torch.nn.functional as F out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # PyTorch 2.0+ 자동 Flash Attention backend ``` ```python # 7. RoPE (Rotary Position Embedding) def rope(x, freqs): # x: (..., d), freqs: (T, d/2) x1, x2 = x.chunk(2, dim=-1) cos, sin = freqs.cos(), freqs.sin() return torch.cat([x1*cos - x2*sin, x1*sin + x2*cos], dim=-1) ``` ```python # 8. GQA (Grouped-Query) class GQA(nn.Module): def __init__(self, d, n_q_heads, n_kv_heads): super().__init__() self.n_q, self.n_kv = n_q_heads, n_kv_heads self.dh = d // n_q_heads self.q = nn.Linear(d, n_q_heads * self.dh) self.k = nn.Linear(d, n_kv_heads * self.dh) self.v = nn.Linear(d, n_kv_heads * self.dh) # KV broadcast n_q / n_kv 배 반복하여 attention ``` ```python # 9. Sliding window (Mistral-style) def sliding_window_mask(T, window=4096): m = torch.tril(torch.ones(T, T)) m = m - torch.tril(torch.ones(T, T), diagonal=-window-1) return m.bool() ``` ```python # 10. Visualizing attention import matplotlib.pyplot as plt attn = mha(x, x, x, need_weights=True, average_attn_weights=False)[1] # (B, h, T, T) plt.imshow(attn[0, 0].cpu().numpy()) # head 0 ``` ## 매 결정 기준 | 상황 | 추천 | |------|------| | Transformer 표준 | **Scaled dot-product MHA** | | 긴 컨텍스트 (>32K) | **Flash Attention 2/3** | | 추론 속도 (LLM) | **GQA / MQA + KV cache** | | Local 패턴 충분 | **Sliding window (Mistral)** | | Encoder-decoder 번역 | **Cross attention** | | 작은 모델 + 작은 d | Bahdanau additive (legacy) | | 위치 표현 | **RoPE** (modern) / ALiBi (long ctx) | | 1M+ 컨텍스트 분산 | **Ring Attention** | ## 🔗 Graph - 부모: [[Transformer_Architecture_and_LLM_Foundations|Transformer Architecture]] - 변형: [[Multi-Head Attention]], [[Flash Attention]], [[Grouped-Query Attention]] - 응용: [[Transformer_Architecture_and_LLM_Foundations|LLM]], [[Transformer_Architecture_and_LLM_Foundations|BERT]], [[GPT]] - Adjacent: [[KV Cache]], [[Long Context]] ## 🤖 LLM 활용 - "이 attention map을 보고 모델이 어떤 토큰에 의존하는지 분석" — interpretability. - 코드에 SDPA / FlashAttention 적용 자동 리팩토링. - Attention 변형 비교표 생성, ablation 가이드. ## ❌ 안티패턴 - **√d_k 정규화 누락**: 큰 d에서 softmax saturation → gradient 소실. - **Causal mask 없는 decoder**: 미래 leak → 학습/추론 불일치. - **벡터화 안 한 attention 루프**: 100배 느림. - **MHA 추론 + KV cache 없음**: 긴 생성에서 O(n²) 재계산. - **Vanilla attention + 100K context**: OOM — Flash Attention 필수. - **Position encoding 누락**: bag-of-words처럼 동작. ## 🧪 검증 / 중복 - 검증: Vaswani(2017) "Attention is All You Need", Bahdanau(2014), Dao(2022 FlashAttention), HuggingFace docs. - 중복: [[Multi-Head Attention]], [[Flash Attention]] (specific) — 본 문서는 family overview. ## 🕓 Changelog - 2026-05-10: 신규 작성. Bahdanau→Luong→SDPA→MHA→Flash→GQA/MQA/Sliding 진화 + 코드 패턴.