id: wiki-2026-0508-sparse-attention
title: Sparse Attention
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: Sparse Self-Attention, Local Attention, Efficient Attention
duplicate_of: none
source_trust_level: A
confidence_score: 0.9
verification_status: applied
tags: transformer, attention, long-context, efficient-llm, sparse
raw_sources:
last_reinforced: 2026-05-10
github_commit: pending
tech_stack: language Python; framework PyTorch/FlashAttention

Sparse Attention

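In one line

"Reduce O(n²) attention to O(n·k) by computing only a subset of token pairs", where k is the number of keys each query attends to (k ≪ n). Sparse attention is the enabler of long-context Transformers. Sliding window plus global tokens (Longformer) is the base pattern; BigBird adds random attention and LongNet adds dilated attention. As of 2026, production stacks combine native sparse attention (DeepSeek NSA, MoBA), SSM hybrids (Mamba2), and FlashAttention-3-class kernels with sparse masks.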

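To make the O(n²) to O(n·k) reduction concrete, the following sketch counts the (query, key) pairs scored by dense attention and by a symmetric sliding window. The sequence length 8192 and window radius 512 are illustrative values chosen for this example, not figures taken from any cited model.

```python
def full_pairs(n: int) -> int:
    """Number of (query, key) pairs scored by dense attention."""
    return n * n


def window_pairs(n: int, w: int) -> int:
    """Pairs scored by a symmetric sliding window of radius w (|i - j| <= w)."""
    # Query i sees keys max(0, i - w) .. min(n - 1, i + w), clipped at the edges.
    return sum(min(n - 1, i + w) - max(0, i - w) + 1 for i in range(n))


n, w = 8192, 512  # illustrative sizes
dense = full_pairs(n)
sparse = window_pairs(n, w)
print(f"dense={dense} window={sparse} ratio={sparse / dense:.3f}")
```

For interior tokens the count per query is 2w + 1, so the k in O(n·k) is roughly the window width.
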
Key points

Sparsity patterns

  • Sliding window: each token attends to its ±w neighbors. Local context. (Longformer, Mistral SWA).
  • Global tokens: [CLS] and other designated tokens attend to every token, and every token attends to them.
  • Dilated: attend at stride k. Long-range reach without full O(n²).
  • Random: each token attends to k random tokens. A component of BigBird.
  • Block sparse: block-diagonal plus selected blocks. FlashAttention-friendly.
  • Learned/adaptive: a routing network decides where to be sparse (NSA, MoBA, 2025).
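
The dilated pattern in the list above has no mask example elsewhere on this page, so here is a minimal sketch. It is a simple strided-window reading of the "Dilated" bullet (attend to every `stride`-th token within a radius of `w` hops), not LongNet's segment-based scheme. The function name, parameters, and CPU default are assumptions made for illustration.

```python
import torch


def dilated_mask(n: int, w: int, stride: int, device: str = "cpu") -> torch.Tensor:
    """Boolean mask: True = allow attention.

    Query i may attend to key j when (j - i) is a multiple of `stride`
    and |j - i| <= w * stride, i.e. at most w hops on each side.
    """
    idx = torch.arange(n, device=device)
    diff = idx.unsqueeze(0) - idx.unsqueeze(1)  # diff[i, j] = j - i
    return (diff.abs() <= w * stride) & (diff % stride == 0)


mask = dilated_mask(n=16, w=2, stride=4)
visible = mask[8].nonzero().flatten().tolist()  # keys visible to query 8
print(visible)
```

With 2w + 1 keys per query, the reach grows to w · stride tokens on each side while the per-query cost stays the same as a plain window of radius w.
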

Historical landmarks

  • Sparse Transformer (OpenAI 2019): factorized attention.
  • Longformer (AllenAI 2020): SWA + global. Scales from 4k to 16k+ tokens.
  • BigBird (Google 2020): random + window + global. Theoretically shown to approximate full attention.
  • LongNet (Microsoft 2023): dilated attention, with a claimed scale of 1B tokens.
  • NSA (DeepSeek 2025): sparse attention applied natively during pretraining.
  • MoBA (Moonshot 2025): mixture-of-block-attention, hierarchical sparsity.

Applications

  1. Long-document QA / summarization.
  2. Code-base wide LLM analysis (Claude 1M context).
  3. Genomics / DNA Transformer.
  4. Video Transformer (frames as tokens).

💻 Patterns

Sliding window mask

import torch

def sliding_window_mask(n: int, w: int, device='cuda') -> torch.Tensor:
    """Boolean mask: True = allow attention."""
    idx = torch.arange(n, device=device)
    diff = idx.unsqueeze(0) - idx.unsqueeze(1)
    return diff.abs() <= w

Longformer-style (SWA + global)

def longformer_mask(n: int, w: int, global_idx: list[int], device='cuda'):
    mask = sliding_window_mask(n, w, device)
    g = torch.zeros(n, dtype=torch.bool, device=device)
    g[global_idx] = True
    # global attends to all + all attend to global
    mask = mask | g.unsqueeze(0) | g.unsqueeze(1)
    return mask
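
A mask built this way is consumed by passing it to an attention call. The sketch below assumes PyTorch 2.0+ and uses `torch.nn.functional.scaled_dot_product_attention`, whose boolean `attn_mask` follows the same convention (True = allow). The sizes, the seed, and the choice of token 0 as the global token are illustrative; it runs on CPU and rebuilds a small mask inline so it is self-contained.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, w, d = 32, 4, 16  # illustrative sizes

# Sliding window of radius w plus token 0 as a global token.
idx = torch.arange(n)
mask = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= w
mask[:, 0] = True  # every query attends to the global token
mask[0, :] = True  # the global token attends to every key

# (batch, heads, seq_len, head_dim); the (n, n) mask broadcasts over batch and heads.
q, k, v = (torch.randn(1, 2, n, d) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Cross-check one query by hand: softmax over the allowed keys only.
i = 10
allowed = mask[i]
scores = (q[0, 0, i] @ k[0, 0, allowed].T) / d ** 0.5
ref = scores.softmax(dim=-1) @ v[0, 0, allowed]
print(torch.allclose(out[0, 0, i], ref, atol=1e-5))
```

This gives the reference semantics only. A dense n × n boolean mask is itself O(n²) in memory, so long sequences need a kernel that skips masked blocks.
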

BigBird random component

def random_attention_mask(n: int, k: int, device='cuda') -> torch.Tensor:
    """Each token attends to k distinct random tokens (the token itself may be among them)."""
    mask = torch.zeros(n, n, dtype=torch.bool, device=device)
    for i in range(n):
        idx = torch.randperm(n, device=device)[:k]
        mask[i, idx] = True
    return mask
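
The random component is only one part of BigBird, which combines it with a window and global tokens. A minimal token-level sketch of that union follows. The actual BigBird implementation works on blocks rather than individual tokens, so treat this as an illustration of the combination only. The function name, the seed handling, and the CPU tensors are assumptions made for this example.

```python
import torch


def bigbird_style_mask(n: int, w: int, global_idx: list[int], k: int, seed: int = 0) -> torch.Tensor:
    """Boolean mask (True = allow): window OR global OR random."""
    gen = torch.Generator().manual_seed(seed)
    idx = torch.arange(n)

    # Window: |i - j| <= w.
    window = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= w

    # Global: listed tokens see everything and are seen by everyone.
    g = torch.zeros(n, dtype=torch.bool)
    g[global_idx] = True
    global_part = g.unsqueeze(0) | g.unsqueeze(1)

    # Random: k distinct keys per query, picked by ranking random scores per row.
    rand_idx = torch.rand(n, n, generator=gen).topk(k, dim=-1).indices
    random_part = torch.zeros(n, n, dtype=torch.bool).scatter_(1, rand_idx, True)

    return window | global_part | random_part


mask = bigbird_style_mask(n=64, w=2, global_idx=[0], k=3)
print(f"density={mask.float().mean().item():.3f}")
```

Ranking random scores with `topk` replaces the per-row `randperm` loop with a single vectorized call, at the cost of materializing an n × n random matrix.
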

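FlashAttention-2 with a sliding window

import torch
from flash_attn import flash_attn_func

# Inputs are (batch, seq_len, heads, head_dim) in fp16/bf16 on a CUDA device.
q, k, v = (torch.randn(1, 4096, 8, 64, device='cuda', dtype=torch.bfloat16) for _ in range(3))
# window_size=(left, right): with causal=True each query sees itself and the 512 keys before it (Mistral-style SWA).
# flash_attn_func takes no arbitrary mask; custom block-sparse patterns go through mask_mod in FlexAttention (PyTorch 2.5+).
out = flash_attn_func(q, k, v, causal=True, window_size=(512, 0))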

FlexAttention (PyTorch 2.5+)

from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def sliding_window(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx).abs() <= 512

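# q, k, v: (batch, heads, seq_len, head_dim), on the same device as the block mask.
# This layout differs from flash_attn_func, which takes (batch, seq_len, heads, head_dim).
# B=None and H=None broadcast the mask over batch and heads.
block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=8192, KV_LEN=8192)
out = flex_attention(q, k, v, block_mask=block_mask)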

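Block-sparse (DeepSeek NSA-style reference sketch)

import torch
import torch.nn.functional as F

def block_sparse_attn(q, k, v, block_size=64, top_k_blocks=8):
    """Reference semantics only. q, k, v: (b, s, h, d); s must be divisible by block_size."""
    b, s, h, d = k.shape
    n_blocks = s // block_size
    # 1. Compute block-level importance via mean-pooled K
    k_blocks = k.reshape(b, n_blocks, block_size, h, d).mean(dim=2)
    scores = torch.einsum('bnhd,bmhd->bnmh', q, k_blocks)  # (b, s, n_blocks, h)
    # 2. Select top-k blocks per query
    _, top_idx = scores.topk(min(top_k_blocks, n_blocks), dim=2)
    # 3. Expand the selection to a token-level mask, then run dense attention inside it.
    #    A real kernel would gather only the selected blocks instead of masking.
    keep = torch.zeros_like(scores, dtype=torch.bool).scatter_(2, top_idx, True)
    mask = keep.repeat_interleave(block_size, dim=2).permute(0, 3, 1, 2)  # (b, h, s_q, s_k)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # SDPA expects (b, h, s, d)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask).transpose(1, 2)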

Mistral SWA (HuggingFace)

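import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package and a supported GPU
    sliding_window=4096,  # passed through as an override of the model config's sliding_window
)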

Adaptive top-k token (Native Sparse)

def topk_attention(q, k, v, top_k=128):
    # (b, h, s, d)
    scores = q @ k.transpose(-2, -1) / q.shape[-1]**0.5
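    # Keep top_k per query. The full score matrix is still computed above,
    # so this sparsifies the weights but saves no compute or memory.
    top_k = min(top_k, scores.shape[-1])  # topk raises if top_k exceeds the key length
    top_v, top_i = scores.topk(top_k, dim=-1)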
    sparse = torch.full_like(scores, float('-inf'))
    sparse.scatter_(-1, top_i, top_v)
    attn = sparse.softmax(dim=-1)
    return attn @ v
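
Two properties of top-k attention are easy to check on random inputs: each query ends up with exactly `top_k` nonzero weights, and setting `top_k` to the full key length recovers dense attention. The sketch below assumes PyTorch 2.0+ for `scaled_dot_product_attention`. The helper name and tensor sizes are illustrative, and it runs on CPU.

```python
import torch
import torch.nn.functional as F


def topk_attention_weights(q: torch.Tensor, k: torch.Tensor, top_k: int) -> torch.Tensor:
    """Attention weights with all but the top_k scores per query masked out."""
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    top_v, top_i = scores.topk(top_k, dim=-1)
    # Everything outside the top_k positions is -inf, so softmax gives it weight 0.
    sparse = torch.full_like(scores, float("-inf")).scatter(-1, top_i, top_v)
    return sparse.softmax(dim=-1)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 64, 16) for _ in range(3))  # (b, h, s, d)

weights = topk_attention_weights(q, k, top_k=8)
nonzero_per_query = (weights > 0).sum(dim=-1)  # number of keys that received weight

full = topk_attention_weights(q, k, top_k=64) @ v  # top_k == s keeps every key
dense = F.scaled_dot_product_attention(q, k, v)
print(int(nonzero_per_query.min()), int(nonzero_per_query.max()))
```
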

Decision criteria

| Situation | Approach |
|---|---|
| 4-32k context, mostly local | Sliding window (Mistral SWA) |
| Long-doc QA with key tokens | Longformer (SWA + global) |
| 100k+ context, hardware-friendly | Block-sparse + FlashAttention |
| Native long-context pretraining | NSA / MoBA (2025+) |
| Inference-only swap | Top-k token sparsification |

Defaults: for inference, SWA + FlashAttention; for pretraining, native sparse attention (NSA-like).

🤖 LLM usage

When to use: explaining the rationale for choosing an attention pattern, drafting mask code, distilling papers (NSA/MoBA). When not to use: writing production kernels (use FA-3 / FlexAttention instead), measuring performance (run a real benchmark).

Anti-patterns

  • Naive mask + softmax: memory is still O(n²). Masking with -inf does not reduce compute.
  • Random sparsity only: quality drops catastrophically. A hybrid (window + global) is needed.
  • Fixed window for all heads: each head has different needs. Per-head adaptive windows work better.
  • No global token: in long-document QA, the [CLS]/question token loses access to the full text.
  • Window too small: perplexity degrades well past the full-attention baseline.
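
As a counterpart to the fixed-window anti-pattern above, here is a minimal sketch of a per-head window mask. The window radii are hypothetical values picked for illustration; in practice they would be tuned or learned. The function name is an assumption, and the tensors live on CPU.

```python
import torch


def per_head_window_mask(n: int, windows: list[int]) -> torch.Tensor:
    """Boolean mask of shape (heads, n, n), True = allow; head h uses radius windows[h]."""
    idx = torch.arange(n)
    dist = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()  # (n, n) token distances
    radius = torch.tensor(windows).view(-1, 1, 1)        # (heads, 1, 1)
    return dist.unsqueeze(0) <= radius                   # broadcasts to (heads, n, n)


windows = [1, 2, 4, 15]  # hypothetical radii: narrow local heads plus one full-range head
mask = per_head_window_mask(16, windows)
print(mask.sum(dim=(1, 2)).tolist())  # allowed pairs per head
```

A mask of shape (heads, n, n) broadcasts over the batch dimension when passed as a boolean `attn_mask` to `torch.nn.functional.scaled_dot_product_attention`.
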

🧪 Verification / duplicates

  • Verified (Longformer arXiv:2004.05150; BigBird arXiv:2007.14062; FlashAttention-3 2024; DeepSeek NSA 2025; PyTorch FlexAttention).
  • Trust level: A.

🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: full content (SWA/BigBird/NSA + FlashAttention/FlexAttention patterns) |